package lib

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"regexp"
	"strings"
)

// injectionPatterns contains patterns that indicate potential LLM injection
// attacks. All patterns are case-insensitive. They are compiled once at
// package initialization.
var injectionPatterns = []*regexp.Regexp{
	// "ignore/forget/disregard" + "all/previous/your" + "instructions/prompts/training".
	regexp.MustCompile(`(?i)(ignore|forget|disregard)\s+(all|previous|your)\s+(instructions|prompts|training)`),
	// "system: you are now ..." style role reassignment.
	regexp.MustCompile(`(?i)system\s*:\s*you\s+are\s+now`),
	// Role names wrapped in angle brackets, square brackets, or braces.
	regexp.MustCompile(`(?i)<\s*(system|assistant|user)\s*>`),
	regexp.MustCompile(`(?i)\[\s*(system|assistant|user)\s*\]`),
	regexp.MustCompile(`(?i)\{\s*(system|assistant|user)\s*\}`),
	// "you are a/an attacker|hacker|malicious" persona assignment.
	regexp.MustCompile(`(?i)you\s+are\s+(an?\s+)?(attacker|hacker|malicious)`),
	// "output:" followed later on the same line by a credential-like word.
	regexp.MustCompile(`(?i)output\s*:\s*.*(?:password|secret|key|token)`),
	// "prompt:" followed later on the same line by an override-like word.
	regexp.MustCompile(`(?i)prompt\s*:\s*.*(?:override|bypass|ignore)`),
}

// SanitizeResult indicates the outcome of sanitization.
type SanitizeResult struct {
	// Cleaned is the input after sanitization has been applied.
	Cleaned string
	// Violations lists labels for what was found, such as "input_truncated"
	// and "injection_pattern_detected".
	Violations []string
	// Blocked is set to true when an injection pattern matched.
	Blocked bool
}

// SanitizeUserInput sanitizes user input for LLM prompts (chat endpoint).
// It enforces a length limit (maxLen, defaulting to 4000 when non-positive),
// replaces matches of injectionPatterns with "[FILTERED]" while setting
// Blocked, and records each finding in Violations.
func SanitizeUserInput(input string, maxLen int) SanitizeResult { result := SanitizeResult{Cleaned: input} // Layer 1: Length enforcement if maxLen <= 0 { maxLen = 4000 // Default limit } if len(input) > maxLen { result.Cleaned = input[:maxLen] result.Violations = append(result.Violations, "input_truncated") } // Layer 2: Injection pattern detection for _, pattern := range injectionPatterns { if matches := pattern.FindAllString(result.Cleaned, -1); len(matches) > 0 { result.Violations = append(result.Violations, "injection_pattern_detected") result.Blocked = true // Replace matches with safe placeholders result.Cleaned = pattern.ReplaceAllString(result.Cleaned, "[FILTERED]") } } // Layer 3: Escape XML-like delimiters that could confuse message boundaries result.Cleaned = strings.ReplaceAll(result.Cleaned, "", "<system>") result.Cleaned = strings.ReplaceAll(result.Cleaned, "", "</system>") result.Cleaned = strings.ReplaceAll(result.Cleaned, "", "<assistant>") result.Cleaned = strings.ReplaceAll(result.Cleaned, "", "</assistant>") result.Cleaned = strings.ReplaceAll(result.Cleaned, "", "<user>") result.Cleaned = strings.ReplaceAll(result.Cleaned, "", "</user>") return result } // WrapUserMessage wraps user input in XML-style delimiters for clear separation from system instructions. // This makes it harder for user input to be interpreted as system/assistant messages. func WrapUserMessage(content string) string { return fmt.Sprintf("\n%s\n", content) } // WrapSystemMessage wraps system content in XML-style delimiters. func WrapSystemMessage(content string) string { return fmt.Sprintf("\n%s\n", content) } // SanitizeHTMLContent prepares HTML content for LLM processing (scrape endpoint). // Removes scripts, styles, and other potentially malicious content before LLM sees it. 
func SanitizeHTMLContent(html string, maxLen int) string {
	if maxLen <= 0 {
		maxLen = 50000 // Default limit for HTML
	}

	// Apply the removal steps in their declared order; later steps rely on
	// whole script/style blocks already being gone.
	cleaned := html
	for _, step := range htmlSanitizeSteps {
		cleaned = step.pattern.ReplaceAllString(cleaned, step.replacement)
	}

	// Length enforcement happens last so the budget is spent on content that
	// survived sanitization. maxLen counts bytes; back up to the start of a
	// rune so the result never ends in a partial UTF-8 sequence.
	if len(cleaned) > maxLen {
		cut := maxLen
		for cut > 0 && cleaned[cut]&0xC0 == 0x80 {
			cut--
		}
		cleaned = cleaned[:cut]
	}

	return cleaned
}

// htmlSanitizeSteps lists the ordered replacements SanitizeHTMLContent
// applies. The patterns are compiled once at package initialization rather
// than on every call.
var htmlSanitizeSteps = []struct {
	pattern     *regexp.Regexp
	replacement string
}{
	// Layer 1: remove <script> blocks entirely, including their contents.
	{regexp.MustCompile(`(?is)<script[^>]*>.*?</script>`), ""},
	// Layer 2: remove <style> blocks entirely, including their contents.
	{regexp.MustCompile(`(?is)<style[^>]*>.*?</style>`), ""},
	// Layer 3: remove onclick, onload, etc. event handlers (double- and
	// single-quoted attribute values).
	{regexp.MustCompile(`(?i)\s+on\w+\s*=\s*"[^"]*"`), ""},
	{regexp.MustCompile(`(?i)\s+on\w+\s*=\s*'[^']*'`), ""},
	// Layer 4: remove javascript: URLs.
	{regexp.MustCompile(`(?i)javascript\s*:\s*[^"'>\s]+`), ""},
	// Layer 5: replace data: URLs that could be SVG/script-based.
	{regexp.MustCompile(`(?i)data\s*:\s*[^"'>\s]+`), "[data-url-removed]"},
	// Layer 6: remove comments that might hide injection attempts.
	{regexp.MustCompile(`(?s)<!--.*?-->`), ""},
	// Layer 7: drop control characters (tab, newline, and carriage return
	// are kept) to reduce hidden content.
	{regexp.MustCompile(`[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]`), ""},
}

// BuildSafeChatMessages constructs a properly delimited message list for chat LLM calls.
// The system prompt comes first, followed by the history (user entries are
// sanitized and wrapped) and finally the sanitized current user message.
// It returns the messages together with the violations found in the current
// user message.
func BuildSafeChatMessages(systemPrompt string, history []map[string]string, userMessage string, maxUserLen int) ([]map[string]interface{}, []string) { messages := []map[string]interface{}{ {"role": "system", "content": WrapSystemMessage(systemPrompt)}, } // Process history through sanitization for _, msg := range history { role := msg["role"] content := msg["content"] // Sanitize based on role if role == "user" { result := SanitizeUserInput(content, maxUserLen) content = WrapUserMessage(result.Cleaned) } // Assistant messages are trusted (came from our system) messages = append(messages, map[string]interface{}{ "role": role, "content": content, }) } // Sanitize current user message result := SanitizeUserInput(userMessage, maxUserLen) messages = append(messages, map[string]interface{}{ "role": "user", "content": WrapUserMessage(result.Cleaned), }) return messages, result.Violations } // BuildSafeScrapePrompt constructs a safe prompt with sanitized HTML content for scraping. // The HTML is wrapped in delimiters to isolate it from instructions. func BuildSafeScrapePrompt(instructions string, htmlContent string, domain string, maxHtmlLen int) string { sanitizedHTML := SanitizeHTMLContent(htmlContent, maxHtmlLen) return fmt.Sprintf(`%s The following is sanitized HTML content from %s, enclosed in delimiters. Analyze this content only. %s `, instructions, domain, sanitizedHTML) } // CallOpenRouter sends a request to OpenRouter (OpenAI-compatible API). 
func CallOpenRouter(apiKey, model string, messages []map[string]interface{}, maxTokens int) (string, error) {
	const (
		// maxResponseBytes bounds how much of the response is buffered.
		maxResponseBytes = 10 << 20
		// maxErrorBodyBytes bounds how much of a failed response is echoed
		// into the returned error.
		maxErrorBodyBytes = 2048
	)

	// Keys loaded from env files often carry stray whitespace or a trailing
	// newline, which would otherwise produce an invalid Authorization header.
	apiKey = strings.TrimSpace(apiKey)
	if apiKey == "" {
		return "", fmt.Errorf("OpenRouter API key not configured")
	}

	reqBody := map[string]interface{}{
		"model":       model,
		"messages":    messages,
		"max_tokens":  maxTokens,
		"temperature": 0.1,
	}

	jsonBody, err := json.Marshal(reqBody)
	if err != nil {
		return "", fmt.Errorf("marshal request: %w", err)
	}

	req, err := http.NewRequest(http.MethodPost, "https://openrouter.ai/api/v1/chat/completions", bytes.NewReader(jsonBody))
	if err != nil {
		return "", fmt.Errorf("create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+apiKey)

	// NOTE(review): http.DefaultClient has no timeout, so a stalled
	// connection blocks this call indefinitely. Consider a shared client
	// with a Timeout.
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("API request: %w", err)
	}
	defer resp.Body.Close()

	// Cap the read so a misbehaving endpoint cannot exhaust memory.
	body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes))
	if err != nil {
		return "", fmt.Errorf("read response: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		if len(body) > maxErrorBodyBytes {
			body = body[:maxErrorBodyBytes]
		}
		return "", fmt.Errorf("OpenRouter API error %d: %s", resp.StatusCode, string(body))
	}

	var oaiResp struct {
		Choices []struct {
			Message struct {
				Content string `json:"content"`
			} `json:"message"`
		} `json:"choices"`
	}
	if err := json.Unmarshal(body, &oaiResp); err != nil {
		return "", fmt.Errorf("parse response: %w", err)
	}
	if len(oaiResp.Choices) == 0 {
		return "", fmt.Errorf("empty response from OpenRouter")
	}

	// Models frequently wrap JSON answers in a Markdown code fence; strip an
	// opening "```json" or "```" and a closing "```" if present.
	text := strings.TrimSpace(oaiResp.Choices[0].Message.Content)
	text = strings.TrimPrefix(text, "```json")
	text = strings.TrimPrefix(text, "```")
	text = strings.TrimSuffix(text, "```")

	return strings.TrimSpace(text), nil
}