package lib
import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"regexp"
	"strings"
	"time"
	"unicode/utf8"
)
// injectionPatterns lists the regular expressions that SanitizeUserInput
// treats as signs of a prompt-injection attempt. Every pattern is compiled
// with the (?i) flag, so matching is case-insensitive.
var injectionPatterns = []*regexp.Regexp{
// Instruction override, e.g. "ignore previous instructions".
regexp.MustCompile(`(?i)(ignore|forget|disregard)\s+(all|previous|your)\s+(instructions|prompts|training)`),
// Role reassignment introduced by a fake system prefix, e.g. "system: you are now".
regexp.MustCompile(`(?i)system\s*:\s*you\s+are\s+now`),
// Role name in angle brackets, e.g. "<system>" (opening form only; no slash).
regexp.MustCompile(`(?i)<\s*(system|assistant|user)\s*>`),
// Role name in square brackets, e.g. "[assistant]".
regexp.MustCompile(`(?i)\[\s*(system|assistant|user)\s*\]`),
// Role name in curly braces, e.g. "{user}".
regexp.MustCompile(`(?i)\{\s*(system|assistant|user)\s*\}`),
// Hostile persona assignment, e.g. "you are a hacker".
regexp.MustCompile(`(?i)you\s+are\s+(an?\s+)?(attacker|hacker|malicious)`),
// "output:" followed later on the same line by a credential keyword.
regexp.MustCompile(`(?i)output\s*:\s*.*(?:password|secret|key|token)`),
// "prompt:" followed later on the same line by override, bypass or ignore.
regexp.MustCompile(`(?i)prompt\s*:\s*.*(?:override|bypass|ignore)`),
}
// SanitizeResult reports the outcome of a SanitizeUserInput call.
type SanitizeResult struct {
// Cleaned is the input text after sanitization has been applied.
Cleaned string
// Violations holds one label per issue found, such as "input_truncated"
// or "injection_pattern_detected". It is nil when nothing was found.
Violations []string
// Blocked is set to true when at least one injection pattern matched.
Blocked bool
}
// roleTagPattern matches opening and closing role tags such as "<system>",
// "</user>" or "< / Assistant >", capturing everything between the angle
// brackets so the brackets alone can be escaped.
var roleTagPattern = regexp.MustCompile(`(?i)<(\s*/?\s*(?:system|assistant|user)\s*)>`)

// SanitizeUserInput sanitizes user input for LLM prompts (chat endpoint).
//
// It applies three layers in order:
//  1. Length enforcement: input longer than maxLen bytes is truncated on a
//     UTF-8 rune boundary and "input_truncated" is recorded. A maxLen <= 0
//     selects the default of 4000 bytes.
//  2. Injection detection: every pattern in injectionPatterns that matches
//     records "injection_pattern_detected" (one entry per matching pattern),
//     sets Blocked, and has its matches replaced with "[FILTERED]".
//  3. Delimiter escaping: any remaining role tag (for example a closing
//     "</system>") has its angle brackets rewritten to "&lt;" / "&gt;" so it
//     cannot be mistaken for a message boundary.
//
// The returned SanitizeResult carries the cleaned text, the recorded
// violations and the blocked flag.
func SanitizeUserInput(input string, maxLen int) SanitizeResult {
	result := SanitizeResult{Cleaned: input}

	// Layer 1: length enforcement.
	if maxLen <= 0 {
		maxLen = 4000 // Default limit
	}
	if len(input) > maxLen {
		// Back up to the start of the rune that straddles the limit so the
		// truncated text never ends in a partial multi-byte sequence.
		cut := maxLen
		for cut > 0 && !utf8.RuneStart(input[cut]) {
			cut--
		}
		result.Cleaned = input[:cut]
		result.Violations = append(result.Violations, "input_truncated")
	}

	// Layer 2: injection pattern detection.
	for _, pattern := range injectionPatterns {
		if !pattern.MatchString(result.Cleaned) {
			continue
		}
		result.Violations = append(result.Violations, "injection_pattern_detected")
		result.Blocked = true
		// Replace matches with a safe placeholder.
		result.Cleaned = pattern.ReplaceAllString(result.Cleaned, "[FILTERED]")
	}

	// Layer 3: escape XML-like role delimiters that survived layer 2 (the
	// detection pattern only covers opening tags, so closing tags land here).
	result.Cleaned = roleTagPattern.ReplaceAllString(result.Cleaned, "&lt;${1}&gt;")

	return result
}
// WrapUserMessage wraps user input in XML-style <user> delimiters for clear
// separation from system instructions. This makes it harder for user input to
// be interpreted as system/assistant messages.
//
// The content is inserted verbatim on its own line between the opening and
// closing tag; no escaping is performed here, so callers should pass text that
// has already been sanitized.
func WrapUserMessage(content string) string {
	return "<user>\n" + content + "\n</user>"
}
// WrapSystemMessage wraps system content in XML-style <system> delimiters.
//
// The content is inserted verbatim on its own line between the opening and
// closing tag, so the result is distinguishable from the output of
// WrapUserMessage.
func WrapSystemMessage(content string) string {
	return "<system>\n" + content + "\n</system>"
}
// Patterns used by SanitizeHTMLContent, compiled once at package load instead
// of on every call.
var (
	// <script ...>...</script> blocks, including their contents.
	htmlScriptBlockRe = regexp.MustCompile(`(?is)<script\b[^>]*>.*?</script\s*>`)
	// <style ...>...</style> blocks, including their contents.
	htmlStyleBlockRe = regexp.MustCompile(`(?is)<style\b[^>]*>.*?</style\s*>`)
	// Inline event handler attributes with a double-quoted value.
	htmlEventAttrDoubleRe = regexp.MustCompile(`(?i)\s+on\w+\s*=\s*"[^"]*"`)
	// Inline event handler attributes with a single-quoted value.
	htmlEventAttrSingleRe = regexp.MustCompile(`(?i)\s+on\w+\s*=\s*'[^']*'`)
	// javascript: URLs up to the next quote, '>' or whitespace.
	htmlJavaScriptURLRe = regexp.MustCompile(`(?i)javascript\s*:\s*[^"'>\s]+`)
	// data: URLs up to the next quote, '>' or whitespace.
	htmlDataURLRe = regexp.MustCompile(`(?i)data\s*:\s*[^"'>\s]+`)
	// HTML comments, possibly spanning several lines.
	htmlCommentRe = regexp.MustCompile(`(?s)<!--.*?-->`)
	// ASCII control characters other than tab, line feed and carriage return.
	htmlControlCharRe = regexp.MustCompile(`[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]`)
)

// SanitizeHTMLContent prepares HTML content for LLM processing (scrape endpoint).
// It removes scripts, styles, and other potentially malicious content before
// the LLM sees it, then limits the result to maxLen bytes (50000 when
// maxLen <= 0). Truncation happens on a UTF-8 rune boundary, so the result may
// be up to three bytes shorter than maxLen.
func SanitizeHTMLContent(html string, maxLen int) string {
	if maxLen <= 0 {
		maxLen = 50000 // Default limit for HTML
	}

	// Layer 1: remove <script> blocks entirely (tags and contents).
	cleaned := htmlScriptBlockRe.ReplaceAllString(html, "")

	// Layer 2: remove <style> blocks entirely (tags and contents).
	cleaned = htmlStyleBlockRe.ReplaceAllString(cleaned, "")

	// Layer 3: remove onclick, onload, etc. event handlers.
	cleaned = htmlEventAttrDoubleRe.ReplaceAllString(cleaned, "")
	cleaned = htmlEventAttrSingleRe.ReplaceAllString(cleaned, "")

	// Layer 4: remove javascript: URLs.
	cleaned = htmlJavaScriptURLRe.ReplaceAllString(cleaned, "")

	// Layer 5: replace data: URLs, which could carry SVG/script payloads.
	cleaned = htmlDataURLRe.ReplaceAllString(cleaned, "[data-url-removed]")

	// Layer 6: remove comments that might hide injection attempts.
	cleaned = htmlCommentRe.ReplaceAllString(cleaned, "")

	// Layer 7: drop control characters to reduce hidden content.
	cleaned = htmlControlCharRe.ReplaceAllString(cleaned, "")

	// Layer 8: length enforcement. Back up to the start of the rune that
	// straddles the limit so the output never ends in a partial sequence.
	if len(cleaned) > maxLen {
		cut := maxLen
		for cut > 0 && !utf8.RuneStart(cleaned[cut]) {
			cut--
		}
		cleaned = cleaned[:cut]
	}
	return cleaned
}
// BuildSafeChatMessages assembles the message list for a chat LLM call.
//
// The list starts with the system prompt (passed through WrapSystemMessage),
// followed by every history entry in order, and ends with the current user
// message. History entries whose "role" is "user" and the current user message
// are run through SanitizeUserInput with maxUserLen and then WrapUserMessage;
// entries with any other role are copied through unchanged.
//
// The second return value holds the violations reported for the current user
// message only; violations found in history entries are not returned.
func BuildSafeChatMessages(systemPrompt string, history []map[string]string, userMessage string, maxUserLen int) ([]map[string]interface{}, []string) {
	var messages []map[string]interface{}
	push := func(role, content string) {
		messages = append(messages, map[string]interface{}{
			"role":    role,
			"content": content,
		})
	}

	push("system", WrapSystemMessage(systemPrompt))

	for _, entry := range history {
		text := entry["content"]
		// Only user-authored turns are sanitized; other roles (e.g. assistant
		// replies produced by our own system) are forwarded as they are.
		if entry["role"] == "user" {
			text = WrapUserMessage(SanitizeUserInput(text, maxUserLen).Cleaned)
		}
		push(entry["role"], text)
	}

	current := SanitizeUserInput(userMessage, maxUserLen)
	push("user", WrapUserMessage(current.Cleaned))

	return messages, current.Violations
}
// BuildSafeScrapePrompt constructs a safe prompt with sanitized HTML content for scraping.
//
// htmlContent is passed through SanitizeHTMLContent with maxHtmlLen and then
// placed between <html_content> and </html_content> delimiters to isolate it
// from the instructions. Any literal copy of those delimiter tags inside the
// page has its angle brackets escaped first, so the page cannot close the
// delimited section early. The instructions come first, followed by a sentence
// naming the domain, followed by the delimited HTML.
func BuildSafeScrapePrompt(instructions string, htmlContent string, domain string, maxHtmlLen int) string {
	sanitizedHTML := SanitizeHTMLContent(htmlContent, maxHtmlLen)

	// Neutralize delimiter look-alikes coming from the scraped page.
	sanitizedHTML = strings.NewReplacer(
		"<html_content>", "&lt;html_content&gt;",
		"</html_content>", "&lt;/html_content&gt;",
	).Replace(sanitizedHTML)

	return fmt.Sprintf(`%s
The following is sanitized HTML content from %s, enclosed in delimiters. Analyze this content only.
<html_content>
%s
</html_content>
`,
		instructions, domain, sanitizedHTML)
}
const (
	// openRouterChatURL is the chat-completions endpoint used by CallOpenRouter.
	openRouterChatURL = "https://openrouter.ai/api/v1/chat/completions"
	// openRouterMaxResponseBytes caps how much of a response body is read
	// into memory.
	openRouterMaxResponseBytes = 10 << 20
)

// openRouterHTTPClient is shared by every CallOpenRouter call so connections
// are reused. The timeout bounds the whole exchange (connect, request, and
// reading the body) so a stalled upstream cannot block the caller forever.
var openRouterHTTPClient = &http.Client{Timeout: 2 * time.Minute}

// CallOpenRouter sends a request to OpenRouter (OpenAI-compatible API).
//
// It posts model, messages and maxTokens (with a fixed temperature of 0.1) as
// JSON, authenticating with apiKey as a bearer token. On success it returns
// the content of the first choice with surrounding whitespace and an optional
// Markdown code fence ("```json" or "```" prefix, "```" suffix) removed.
//
// An error is returned when apiKey is empty, the request cannot be built or
// sent, the response status is not 200 (the response body is included in the
// error text), the body cannot be read or parsed, or no choices are returned.
func CallOpenRouter(apiKey, model string, messages []map[string]interface{}, maxTokens int) (string, error) {
	if apiKey == "" {
		return "", fmt.Errorf("OpenRouter API key not configured")
	}

	jsonBody, err := json.Marshal(map[string]interface{}{
		"model":       model,
		"messages":    messages,
		"max_tokens":  maxTokens,
		"temperature": 0.1,
	})
	if err != nil {
		return "", fmt.Errorf("marshal request: %w", err)
	}

	req, err := http.NewRequest(http.MethodPost, openRouterChatURL, bytes.NewReader(jsonBody))
	if err != nil {
		return "", fmt.Errorf("create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+apiKey)

	resp, err := openRouterHTTPClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("API request: %w", err)
	}
	defer resp.Body.Close()

	// Bound the read so an oversized or runaway response cannot exhaust memory.
	body, err := io.ReadAll(io.LimitReader(resp.Body, openRouterMaxResponseBytes))
	if err != nil {
		return "", fmt.Errorf("read response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("OpenRouter API error %d: %s", resp.StatusCode, string(body))
	}

	var oaiResp struct {
		Choices []struct {
			Message struct {
				Content string `json:"content"`
			} `json:"message"`
		} `json:"choices"`
	}
	if err := json.Unmarshal(body, &oaiResp); err != nil {
		return "", fmt.Errorf("parse response: %w", err)
	}
	if len(oaiResp.Choices) == 0 {
		return "", fmt.Errorf("empty response from OpenRouter")
	}

	// Models often wrap JSON answers in a Markdown code fence; peel it off.
	text := strings.TrimSpace(oaiResp.Choices[0].Message.Content)
	text = strings.TrimPrefix(text, "```json")
	text = strings.TrimPrefix(text, "```")
	text = strings.TrimSuffix(text, "```")
	return strings.TrimSpace(text), nil
}