453 lines
14 KiB
Go
453 lines
14 KiB
Go
// Package proxy implements an HTTPS MITM proxy with LLM-based policy evaluation.
//
// Architecture:
//   - Agent sets HTTP_PROXY=http://localhost:19840 (or configured port)
//   - For plain HTTP: proxy injects Authorization/headers, forwards
//   - For HTTPS: proxy performs CONNECT tunnel, generates per-host TLS cert (signed by local CA)
//   - Before injecting credentials: optional LLM policy evaluation (intent check)
//
// Credential injection:
//   - Scans request for placeholder patterns: {{clavitor.entry_title.field_label}}
//   - Also injects via per-host credential rules stored in vault
//   - Tier check: L2 fields are never injected (identity/card data)
|
|
package proxy
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"encoding/pem"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"math/big"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httputil"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// Config carries every setting the proxy needs at start-up.
type Config struct {
	ListenAddr string // address the proxy listens on, e.g. "127.0.0.1:19840"
	DataDir    string // vault data directory; the CA cert/key are stored beneath it
	VaultKey   []byte // L1 decryption key used to read credentials for injection
	DBPath     string // path to the vault SQLite database

	LLMEnabled bool   // turns on LLM-based intent evaluation before credential injection
	LLMBaseURL string // base URL of the OpenAI-compatible LLM API
	LLMAPIKey  string // API key sent with LLM requests
	LLMModel   string // model used for policy evaluation
}
|
|
|
|
// Proxy is the MITM proxy server.
|
|
type Proxy struct {
|
|
cfg Config
|
|
ca *tls.Certificate
|
|
caCert *x509.Certificate
|
|
caKey *rsa.PrivateKey
|
|
certMu sync.Mutex
|
|
certs map[string]*tls.Certificate // hostname → generated cert (cache)
|
|
}
|
|
|
|
// New creates a new Proxy. Generates or loads the CA cert from DataDir.
|
|
func New(cfg Config) (*Proxy, error) {
|
|
p := &Proxy{
|
|
cfg: cfg,
|
|
certs: make(map[string]*tls.Certificate),
|
|
}
|
|
if err := p.loadOrCreateCA(); err != nil {
|
|
return nil, fmt.Errorf("proxy CA: %w", err)
|
|
}
|
|
return p, nil
|
|
}
|
|
|
|
// ListenAndServe starts the proxy server. Blocks until stopped.
|
|
func (p *Proxy) ListenAndServe() error {
|
|
ln, err := net.Listen("tcp", p.cfg.ListenAddr)
|
|
if err != nil {
|
|
return fmt.Errorf("proxy listen %s: %w", p.cfg.ListenAddr, err)
|
|
}
|
|
log.Printf("proxy: listening on %s (LLM policy: %v)", p.cfg.ListenAddr, p.cfg.LLMEnabled)
|
|
srv := &http.Server{
|
|
Handler: p,
|
|
ReadTimeout: 30 * time.Second,
|
|
WriteTimeout: 30 * time.Second,
|
|
}
|
|
return srv.Serve(ln)
|
|
}
|
|
|
|
// ServeHTTP handles all incoming proxy requests.
|
|
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method == http.MethodConnect {
|
|
p.handleCONNECT(w, r)
|
|
return
|
|
}
|
|
p.handleHTTP(w, r)
|
|
}
|
|
|
|
// handleHTTP handles plain HTTP proxy requests.
|
|
func (p *Proxy) handleHTTP(w http.ResponseWriter, r *http.Request) {
|
|
// Remove proxy-specific headers
|
|
r.RequestURI = ""
|
|
r.Header.Del("Proxy-Connection")
|
|
r.Header.Del("Proxy-Authenticate")
|
|
r.Header.Del("Proxy-Authorization")
|
|
|
|
// Inject credentials if applicable
|
|
if err := p.injectCredentials(r); err != nil {
|
|
log.Printf("proxy: credential injection error for %s: %v", r.URL.Host, err)
|
|
// Non-fatal: continue without injection
|
|
}
|
|
|
|
// Forward the request
|
|
rp := &httputil.ReverseProxy{
|
|
Director: func(req *http.Request) {},
|
|
}
|
|
rp.ServeHTTP(w, r)
|
|
}
|
|
|
|
// handleCONNECT handles HTTPS CONNECT tunnel requests.
|
|
func (p *Proxy) handleCONNECT(w http.ResponseWriter, r *http.Request) {
|
|
host := r.Host
|
|
if !strings.Contains(host, ":") {
|
|
host = host + ":443"
|
|
}
|
|
hostname, _, _ := net.SplitHostPort(host)
|
|
|
|
// Acknowledge the CONNECT
|
|
w.WriteHeader(http.StatusOK)
|
|
|
|
// Hijack the connection
|
|
hijacker, ok := w.(http.Hijacker)
|
|
if !ok {
|
|
log.Printf("proxy: CONNECT hijack not supported")
|
|
return
|
|
}
|
|
clientConn, _, err := hijacker.Hijack()
|
|
if err != nil {
|
|
log.Printf("proxy: CONNECT hijack error: %v", err)
|
|
return
|
|
}
|
|
defer clientConn.Close()
|
|
|
|
// Generate a certificate for this hostname
|
|
cert, err := p.certForHost(hostname)
|
|
if err != nil {
|
|
log.Printf("proxy: cert generation failed for %s: %v", hostname, err)
|
|
return
|
|
}
|
|
|
|
// Wrap client connection in TLS (using our MITM cert)
|
|
tlsCfg := &tls.Config{
|
|
Certificates: []tls.Certificate{*cert},
|
|
MinVersion: tls.VersionTLS12,
|
|
}
|
|
tlsClientConn := tls.Server(clientConn, tlsCfg)
|
|
defer tlsClientConn.Close()
|
|
if err := tlsClientConn.Handshake(); err != nil {
|
|
log.Printf("proxy: TLS handshake failed for %s: %v", hostname, err)
|
|
return
|
|
}
|
|
|
|
// Connect to real upstream
|
|
upstreamConn, err := tls.Dial("tcp", host, &tls.Config{
|
|
ServerName: hostname,
|
|
MinVersion: tls.VersionTLS12,
|
|
})
|
|
if err != nil {
|
|
log.Printf("proxy: upstream dial failed for %s: %v", host, err)
|
|
return
|
|
}
|
|
defer upstreamConn.Close()
|
|
|
|
// Intercept HTTP traffic between client and upstream
|
|
p.interceptHTTP(tlsClientConn, upstreamConn, hostname)
|
|
}
|
|
|
|
// interceptHTTP reads HTTP requests from the client, injects credentials, forwards to upstream.
|
|
func (p *Proxy) interceptHTTP(clientConn net.Conn, upstreamConn net.Conn, hostname string) {
|
|
// Use Go's http.ReadRequest to parse the client's request
|
|
clientReader := newBufReader(clientConn)
|
|
|
|
for {
|
|
req, err := http.ReadRequest(clientReader)
|
|
if err != nil {
|
|
if err != io.EOF {
|
|
log.Printf("proxy: read request error for %s: %v", hostname, err)
|
|
}
|
|
return
|
|
}
|
|
|
|
// Set the correct URL for upstream forwarding
|
|
req.URL.Scheme = "https"
|
|
req.URL.Host = hostname
|
|
req.RequestURI = ""
|
|
|
|
// Inject credentials
|
|
if err := p.injectCredentials(req); err != nil {
|
|
log.Printf("proxy: credential injection error for %s: %v", hostname, err)
|
|
}
|
|
|
|
// Forward to upstream
|
|
if err := req.Write(upstreamConn); err != nil {
|
|
log.Printf("proxy: upstream write error for %s: %v", hostname, err)
|
|
return
|
|
}
|
|
|
|
// Read upstream response and forward to client
|
|
upstreamReader := newBufReader(upstreamConn)
|
|
resp, err := http.ReadResponse(upstreamReader, req)
|
|
if err != nil {
|
|
log.Printf("proxy: upstream read error for %s: %v", hostname, err)
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if err := resp.Write(clientConn); err != nil {
|
|
log.Printf("proxy: client write error for %s: %v", hostname, err)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// injectCredentials scans the request for credential placeholders and injects them.
|
|
// Placeholder format: {{clavitor.entry_title.field_label}} in headers, URL, or body.
|
|
// Also applies host-based automatic injection rules from vault.
|
|
// L2 (identity/card) fields are NEVER injected.
|
|
func (p *Proxy) injectCredentials(r *http.Request) error {
|
|
if p.cfg.VaultKey == nil {
|
|
return nil // No vault key — skip injection
|
|
}
|
|
|
|
// Check for LLM policy evaluation
|
|
if p.cfg.LLMEnabled {
|
|
allowed, reason, err := p.evaluatePolicy(r)
|
|
if err != nil {
|
|
log.Printf("proxy: LLM policy eval error: %v (allowing)", err)
|
|
} else if !allowed {
|
|
log.Printf("proxy: LLM policy DENIED %s %s: %s", r.Method, r.URL, reason)
|
|
return fmt.Errorf("policy denied: %s", reason)
|
|
}
|
|
}
|
|
|
|
// TODO: Implement placeholder substitution once vault DB integration is wired in.
|
|
// Pattern: scan r.Header values, r.URL, r.Body for {{clavitor.TITLE.FIELD}}
|
|
// Lookup entry by title (case-insensitive), get field by label, verify Tier != L2
|
|
// Replace placeholder with decrypted field value.
|
|
//
|
|
// Auto-injection (host rules):
|
|
// Vault entries can specify "proxy_inject_hosts": ["api.github.com"] in metadata
|
|
// When a request matches, inject the entry's L1 fields as headers per a configured map.
|
|
//
|
|
// This stub returns nil — no injection until DB wiring is complete.
|
|
return nil
|
|
}
|
|
|
|
// evaluatePolicy calls the configured LLM to evaluate whether this request
|
|
// is consistent with the expected behavior of an AI agent (vs. exfiltration/abuse).
|
|
func (p *Proxy) evaluatePolicy(r *http.Request) (allowed bool, reason string, err error) {
|
|
if p.cfg.LLMBaseURL == "" || p.cfg.LLMAPIKey == "" {
|
|
return true, "LLM not configured", nil
|
|
}
|
|
|
|
// Build a concise request summary for the LLM
|
|
summary := fmt.Sprintf("Method: %s\nHost: %s\nPath: %s\nContent-Type: %s",
|
|
r.Method, r.Host, r.URL.Path,
|
|
r.Header.Get("Content-Type"))
|
|
|
|
prompt := `You are a security policy evaluator for an AI agent credential proxy.
|
|
|
|
The following outbound HTTP request is about to have credentials injected and be forwarded.
|
|
Evaluate whether this request is consistent with normal AI agent behavior (coding, API calls, deployment)
|
|
vs. suspicious activity (credential exfiltration, unexpected destinations, data harvesting).
|
|
|
|
Request summary:
|
|
` + summary + `
|
|
|
|
Respond with JSON only: {"allowed": true/false, "reason": "one sentence"}`
|
|
|
|
_ = prompt // Used when LLM call is implemented below
|
|
|
|
// TODO: Implement actual LLM call using cfg.LLMBaseURL + cfg.LLMAPIKey + cfg.LLMModel
|
|
// For now: always allow (policy eval is opt-in, not blocking by default)
|
|
// Real implementation: POST to /v1/chat/completions, parse JSON response
|
|
return true, "policy evaluation not yet implemented", nil
|
|
}
|
|
|
|
// certForHost returns a TLS certificate for the given hostname, generating one if needed.
|
|
func (p *Proxy) certForHost(hostname string) (*tls.Certificate, error) {
|
|
p.certMu.Lock()
|
|
defer p.certMu.Unlock()
|
|
|
|
if cert, ok := p.certs[hostname]; ok {
|
|
// Check if cert is still valid (> 1 hour remaining)
|
|
if time.Until(cert.Leaf.NotAfter) > time.Hour {
|
|
return cert, nil
|
|
}
|
|
}
|
|
|
|
// Generate a new cert signed by our CA
|
|
cert, err := p.generateCert(hostname)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
p.certs[hostname] = cert
|
|
return cert, nil
|
|
}
|
|
|
|
// generateCert generates a TLS cert for the given hostname, signed by the proxy CA.
|
|
func (p *Proxy) generateCert(hostname string) (*tls.Certificate, error) {
|
|
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("generate key: %w", err)
|
|
}
|
|
|
|
serial, _ := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
|
tmpl := &x509.Certificate{
|
|
SerialNumber: serial,
|
|
Subject: pkix.Name{CommonName: hostname},
|
|
DNSNames: []string{hostname},
|
|
NotBefore: time.Now().Add(-time.Minute),
|
|
NotAfter: time.Now().Add(24 * time.Hour),
|
|
KeyUsage: x509.KeyUsageDigitalSignature,
|
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
|
}
|
|
|
|
// Add IP SAN if hostname is an IP
|
|
if ip := net.ParseIP(hostname); ip != nil {
|
|
tmpl.IPAddresses = []net.IP{ip}
|
|
tmpl.DNSNames = nil
|
|
}
|
|
|
|
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, p.caCert, &key.PublicKey, p.caKey)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create cert: %w", err)
|
|
}
|
|
|
|
leaf, err := x509.ParseCertificate(certDER)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse cert: %w", err)
|
|
}
|
|
|
|
tlsCert := &tls.Certificate{
|
|
Certificate: [][]byte{certDER},
|
|
PrivateKey: key,
|
|
Leaf: leaf,
|
|
}
|
|
return tlsCert, nil
|
|
}
|
|
|
|
// loadOrCreateCA loads the proxy CA cert/key from DataDir, or generates new ones.
|
|
func (p *Proxy) loadOrCreateCA() error {
|
|
caDir := filepath.Join(p.cfg.DataDir, "proxy")
|
|
if err := os.MkdirAll(caDir, 0700); err != nil {
|
|
return err
|
|
}
|
|
certPath := filepath.Join(caDir, "ca.crt")
|
|
keyPath := filepath.Join(caDir, "ca.key")
|
|
|
|
// Try to load existing CA
|
|
if _, err := os.Stat(certPath); err == nil {
|
|
certPEM, err := os.ReadFile(certPath)
|
|
if err != nil {
|
|
return fmt.Errorf("read CA cert: %w", err)
|
|
}
|
|
keyPEM, err := os.ReadFile(keyPath)
|
|
if err != nil {
|
|
return fmt.Errorf("read CA key: %w", err)
|
|
}
|
|
tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
|
|
if err != nil {
|
|
return fmt.Errorf("parse CA keypair: %w", err)
|
|
}
|
|
tlsCert.Leaf, err = x509.ParseCertificate(tlsCert.Certificate[0])
|
|
if err != nil {
|
|
return fmt.Errorf("parse CA cert: %w", err)
|
|
}
|
|
// Check expiry — regenerate if < 7 days left
|
|
if time.Until(tlsCert.Leaf.NotAfter) < 7*24*time.Hour {
|
|
log.Printf("proxy: CA cert expires soon (%s), regenerating", tlsCert.Leaf.NotAfter.Format("2006-01-02"))
|
|
} else {
|
|
p.ca = &tlsCert
|
|
p.caCert = tlsCert.Leaf
|
|
p.caKey = tlsCert.PrivateKey.(*rsa.PrivateKey)
|
|
log.Printf("proxy: loaded CA cert (expires %s)", tlsCert.Leaf.NotAfter.Format("2006-01-02"))
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// Generate new CA
|
|
log.Printf("proxy: generating new CA cert...")
|
|
key, err := rsa.GenerateKey(rand.Reader, 4096)
|
|
if err != nil {
|
|
return fmt.Errorf("generate CA key: %w", err)
|
|
}
|
|
|
|
serial, _ := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
|
tmpl := &x509.Certificate{
|
|
SerialNumber: serial,
|
|
Subject: pkix.Name{CommonName: "Clavitor Proxy CA", Organization: []string{"Clavitor"}},
|
|
NotBefore: time.Now().Add(-time.Minute),
|
|
NotAfter: time.Now().Add(365 * 24 * time.Hour),
|
|
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
|
BasicConstraintsValid: true,
|
|
IsCA: true,
|
|
MaxPathLen: 0,
|
|
}
|
|
|
|
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
|
|
if err != nil {
|
|
return fmt.Errorf("create CA cert: %w", err)
|
|
}
|
|
leaf, _ := x509.ParseCertificate(certDER)
|
|
|
|
// Write to disk
|
|
certFile, err := os.OpenFile(certPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
|
|
if err != nil {
|
|
return fmt.Errorf("write CA cert: %w", err)
|
|
}
|
|
pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})
|
|
certFile.Close()
|
|
|
|
keyFile, err := os.OpenFile(keyPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
|
|
if err != nil {
|
|
return fmt.Errorf("write CA key: %w", err)
|
|
}
|
|
pem.Encode(keyFile, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
|
|
keyFile.Close()
|
|
|
|
p.ca = &tls.Certificate{Certificate: [][]byte{certDER}, PrivateKey: key, Leaf: leaf}
|
|
p.caCert = leaf
|
|
p.caKey = key
|
|
|
|
log.Printf("proxy: CA cert generated at %s (install in OS trust store or pass --proxy-ca)", certPath)
|
|
log.Printf("proxy: CA cert path: %s", certPath)
|
|
return nil
|
|
}
|
|
|
|
// CACertPath returns the path to the proxy CA certificate (for user installation).
|
|
func (p *Proxy) CACertPath() string {
|
|
return filepath.Join(p.cfg.DataDir, "proxy", "ca.crt")
|
|
}
|