233 lines
7.9 KiB
Go
233 lines
7.9 KiB
Go
package lib
|
|
|
|
import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"regexp"
	"strings"
	"time"
	"unicode/utf8"
)
|
|
|
|
// injectionPatterns holds case-insensitive regular expressions whose matches are
// treated by SanitizeUserInput as likely prompt-injection attempts. Every match
// is replaced with "[FILTERED]".
var injectionPatterns = []*regexp.Regexp{
	// Instruction-override phrasing, e.g. "ignore all instructions".
	regexp.MustCompile(`(?i)(ignore|forget|disregard)\s+(all|previous|your)\s+(instructions|prompts|training)`),
	// Role reassignment, e.g. "system: you are now ...".
	regexp.MustCompile(`(?i)system\s*:\s*you\s+are\s+now`),
	// Role and message-delimiter tags in angle brackets, opening OR closing:
	// <system>, </user>, </user_message>, <system_instructions>, ...
	regexp.MustCompile(`(?i)<\s*/?\s*(system|assistant|user|system_instructions|user_message)\s*>`),
	// Role markers in square or curly brackets: [system], {assistant}, ...
	regexp.MustCompile(`(?i)\[\s*(system|assistant|user)\s*\]`),
	regexp.MustCompile(`(?i)\{\s*(system|assistant|user)\s*\}`),
	// Persona hijacking, e.g. "you are a hacker".
	regexp.MustCompile(`(?i)you\s+are\s+(an?\s+)?(attacker|hacker|malicious)`),
	// Attempts to elicit credentials or override the prompt.
	regexp.MustCompile(`(?i)output\s*:\s*.*(?:password|secret|key|token)`),
	regexp.MustCompile(`(?i)prompt\s*:\s*.*(?:override|bypass|ignore)`),
}

// roleTagPattern matches any role or message-delimiter tag that survived
// injection filtering — opening or closing, any case, with or without
// attributes — so SanitizeUserInput can HTML-escape it. The \b keeps
// look-alike tags such as <username> untouched.
var roleTagPattern = regexp.MustCompile(`(?i)<(\s*/?\s*(?:system_instructions|user_message|system|assistant|user)\b[^<>]*)>`)

// SanitizeResult reports the outcome of SanitizeUserInput.
type SanitizeResult struct {
	// Cleaned is the sanitized text. It is populated even when Blocked is true.
	Cleaned string
	// Violations lists what was found, each tag at most once:
	// "input_truncated" and/or "injection_pattern_detected".
	Violations []string
	// Blocked is true when at least one injection pattern matched.
	Blocked bool
}

// SanitizeUserInput sanitizes untrusted user text before it is placed in an LLM
// prompt (chat endpoint). It applies three layers:
//
//  1. Length: input longer than maxLen bytes (4000 when maxLen <= 0) is
//     truncated, backing up so a multi-byte UTF-8 character is never split,
//     and "input_truncated" is recorded.
//  2. Injection detection: every match of injectionPatterns is replaced with
//     "[FILTERED]". If any pattern matched, "injection_pattern_detected" is
//     recorded once and Blocked is set.
//  3. Delimiter escaping: remaining role/message-delimiter tags (for example
//     <user_message id="x">) are HTML-escaped so they cannot open or close a
//     message boundary.
//
// The cleaned text is always returned; callers decide what to do when Blocked
// is true.
func SanitizeUserInput(input string, maxLen int) SanitizeResult {
	result := SanitizeResult{Cleaned: input}

	// Layer 1: length enforcement.
	if maxLen <= 0 {
		maxLen = 4000 // default limit, in bytes
	}
	if len(input) > maxLen {
		// Back up to the first byte of a rune so the cut never lands inside a
		// multi-byte sequence (which would leave invalid UTF-8 at the end).
		cut := maxLen
		for cut > 0 && !utf8.RuneStart(input[cut]) {
			cut--
		}
		result.Cleaned = input[:cut]
		result.Violations = append(result.Violations, "input_truncated")
	}

	// Layer 2: injection pattern detection. Patterns run in order on the
	// progressively filtered text.
	detected := false
	for _, pattern := range injectionPatterns {
		if pattern.MatchString(result.Cleaned) {
			detected = true
			result.Cleaned = pattern.ReplaceAllString(result.Cleaned, "[FILTERED]")
		}
	}
	if detected {
		result.Violations = append(result.Violations, "injection_pattern_detected")
		result.Blocked = true
	}

	// Layer 3: escape any remaining delimiter-like tags (e.g. ones carrying
	// attributes, which the Layer 2 pattern does not match) so they cannot be
	// read as message boundaries.
	result.Cleaned = roleTagPattern.ReplaceAllString(result.Cleaned, "&lt;${1}&gt;")

	return result
}
|
|
|
|
// WrapUserMessage encloses content between <user_message> and </user_message>
// tags, each on its own line, to mark it as user-supplied text distinct from
// system instructions. The content is inserted verbatim — no escaping is done
// here.
func WrapUserMessage(content string) string {
	const (
		openTag  = "<user_message>"
		closeTag = "</user_message>"
	)
	// All operands are strings, so fmt.Sprint inserts no separating spaces.
	return fmt.Sprint(openTag, "\n", content, "\n", closeTag)
}
|
|
|
|
// WrapSystemMessage encloses content between <system_instructions> and
// </system_instructions> tags, each on its own line. The content is inserted
// verbatim.
func WrapSystemMessage(content string) string {
	const tag = "system_instructions"
	// %[1]s reuses the tag name for both the opening and the closing delimiter.
	return fmt.Sprintf("<%[1]s>\n%[2]s\n</%[1]s>", tag, content)
}
|
|
|
|
// Patterns used by SanitizeHTMLContent. They are compiled once at package
// initialisation instead of on every call.
var (
	// <script>...</script> blocks including their contents (case-insensitive,
	// may span newlines).
	htmlScriptBlockPattern = regexp.MustCompile(`(?is)<script[^>]*>.*?</script>`)
	// <style>...</style> blocks including their contents.
	htmlStyleBlockPattern = regexp.MustCompile(`(?is)<style[^>]*>.*?</style>`)
	// Inline event-handler attributes (onclick, onload, ...) with a double- or
	// single-quoted value. Unquoted values are not matched.
	htmlEventHandlerPattern = regexp.MustCompile(`(?i)\s+on\w+\s*=\s*(?:"[^"]*"|'[^']*')`)
	// javascript: URLs, up to the next quote, '>' or whitespace.
	htmlJavaScriptURLPattern = regexp.MustCompile(`(?i)javascript\s*:\s*[^"'>\s]+`)
	// data: URLs. The leading \b keeps words that merely end in "data"
	// (e.g. "Metadata: ...") from being mangled.
	htmlDataURLPattern = regexp.MustCompile(`(?i)\bdata\s*:\s*[^"'>\s]+`)
	// HTML comments, which could hide injected text.
	htmlCommentPattern = regexp.MustCompile(`(?s)<!--.*?-->`)
	// ASCII control characters other than tab, newline and carriage return.
	htmlControlCharPattern = regexp.MustCompile(`[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]`)
)

// SanitizeHTMLContent prepares scraped HTML for LLM processing (scrape endpoint).
//
// It is a best-effort, regex-based filter, not an HTML parser. In order, it:
//   - removes <script> and <style> blocks;
//   - removes quoted on* event-handler attributes;
//   - deletes javascript: URLs;
//   - replaces data: URLs with "[data-url-removed]";
//   - removes HTML comments;
//   - strips ASCII control characters (keeping tab, newline and carriage return).
//
// Finally it truncates the result to maxLen bytes (50000 when maxLen <= 0)
// without splitting a multi-byte UTF-8 character.
func SanitizeHTMLContent(html string, maxLen int) string {
	if maxLen <= 0 {
		maxLen = 50000 // default limit for HTML, in bytes
	}

	cleaned := htmlScriptBlockPattern.ReplaceAllString(html, "")
	cleaned = htmlStyleBlockPattern.ReplaceAllString(cleaned, "")
	cleaned = htmlEventHandlerPattern.ReplaceAllString(cleaned, "")
	cleaned = htmlJavaScriptURLPattern.ReplaceAllString(cleaned, "")
	cleaned = htmlDataURLPattern.ReplaceAllString(cleaned, "[data-url-removed]")
	cleaned = htmlCommentPattern.ReplaceAllString(cleaned, "")
	cleaned = htmlControlCharPattern.ReplaceAllString(cleaned, "")

	if len(cleaned) > maxLen {
		// Back up to the first byte of a rune so truncation never leaves a
		// partial multi-byte sequence at the end.
		cut := maxLen
		for cut > 0 && !utf8.RuneStart(cleaned[cut]) {
			cut--
		}
		cleaned = cleaned[:cut]
	}
	return cleaned
}
|
|
|
|
// BuildSafeChatMessages constructs a properly delimited message list for chat LLM calls.
|
|
// Takes raw user input and sanitizes it before inclusion in messages.
|
|
func BuildSafeChatMessages(systemPrompt string, history []map[string]string, userMessage string, maxUserLen int) ([]map[string]interface{}, []string) {
|
|
messages := []map[string]interface{}{
|
|
{"role": "system", "content": WrapSystemMessage(systemPrompt)},
|
|
}
|
|
|
|
// Process history through sanitization
|
|
for _, msg := range history {
|
|
role := msg["role"]
|
|
content := msg["content"]
|
|
|
|
// Sanitize based on role
|
|
if role == "user" {
|
|
result := SanitizeUserInput(content, maxUserLen)
|
|
content = WrapUserMessage(result.Cleaned)
|
|
}
|
|
// Assistant messages are trusted (came from our system)
|
|
|
|
messages = append(messages, map[string]interface{}{
|
|
"role": role,
|
|
"content": content,
|
|
})
|
|
}
|
|
|
|
// Sanitize current user message
|
|
result := SanitizeUserInput(userMessage, maxUserLen)
|
|
messages = append(messages, map[string]interface{}{
|
|
"role": "user",
|
|
"content": WrapUserMessage(result.Cleaned),
|
|
})
|
|
|
|
return messages, result.Violations
|
|
}
|
|
|
|
// BuildSafeScrapePrompt constructs a safe prompt with sanitized HTML content for scraping.
|
|
// The HTML is wrapped in delimiters to isolate it from instructions.
|
|
func BuildSafeScrapePrompt(instructions string, htmlContent string, domain string, maxHtmlLen int) string {
|
|
sanitizedHTML := SanitizeHTMLContent(htmlContent, maxHtmlLen)
|
|
|
|
return fmt.Sprintf(`%s
|
|
|
|
The following is sanitized HTML content from %s, enclosed in delimiters. Analyze this content only.
|
|
|
|
<scraped_html>
|
|
%s
|
|
</scraped_html>`,
|
|
instructions, domain, sanitizedHTML)
|
|
}
|
|
|
|
// openRouterURL is the OpenRouter chat-completions endpoint (OpenAI-compatible).
const openRouterURL = "https://openrouter.ai/api/v1/chat/completions"

// openRouterMaxResponseBytes caps how much of a response body is buffered in
// memory.
const openRouterMaxResponseBytes = 10 << 20 // 10 MiB

// openRouterClient is shared across calls so connections are reused. Unlike
// http.DefaultClient it has a timeout, so a stalled upstream cannot block a
// caller forever. The 2-minute bound allows for slow completions; tune as
// needed.
var openRouterClient = &http.Client{Timeout: 2 * time.Minute}

// CallOpenRouter sends a chat-completion request to OpenRouter and returns the
// text of the first choice.
//
// The request carries model, messages and maxTokens with a fixed temperature
// of 0.1. From the returned text, surrounding whitespace, a leading "```json"
// or "```" fence, and a trailing "```" fence are removed.
//
// An error is returned in any of these cases:
//   - apiKey is empty;
//   - the request cannot be built or sent (including timeout);
//   - the status is not 200 OK (the error includes the response body);
//   - the body is not valid JSON;
//   - the response contains no choices.
func CallOpenRouter(apiKey, model string, messages []map[string]interface{}, maxTokens int) (string, error) {
	if apiKey == "" {
		return "", fmt.Errorf("OpenRouter API key not configured")
	}

	reqBody := map[string]interface{}{
		"model":       model,
		"messages":    messages,
		"max_tokens":  maxTokens,
		"temperature": 0.1,
	}
	jsonBody, err := json.Marshal(reqBody)
	if err != nil {
		return "", fmt.Errorf("marshal request: %w", err)
	}

	req, err := http.NewRequest(http.MethodPost, openRouterURL, bytes.NewReader(jsonBody))
	if err != nil {
		return "", fmt.Errorf("create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+apiKey)

	resp, err := openRouterClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("API request: %w", err)
	}
	defer resp.Body.Close()

	// Bound the read so an oversized or runaway response cannot exhaust memory.
	body, err := io.ReadAll(io.LimitReader(resp.Body, openRouterMaxResponseBytes))
	if err != nil {
		return "", fmt.Errorf("read response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("OpenRouter API error %d: %s", resp.StatusCode, string(body))
	}

	var oaiResp struct {
		Choices []struct {
			Message struct {
				Content string `json:"content"`
			} `json:"message"`
		} `json:"choices"`
	}
	if err := json.Unmarshal(body, &oaiResp); err != nil {
		return "", fmt.Errorf("parse response: %w", err)
	}
	if len(oaiResp.Choices) == 0 {
		return "", fmt.Errorf("empty response from OpenRouter")
	}

	// Models often wrap JSON answers in a Markdown code fence; strip it.
	text := strings.TrimSpace(oaiResp.Choices[0].Message.Content)
	text = strings.TrimPrefix(text, "```json")
	text = strings.TrimPrefix(text, "```")
	text = strings.TrimSuffix(text, "```")
	return strings.TrimSpace(text), nil
}
|