// inou/lib/llm.go
// 258 lines, 7.5 KiB, Go
package lib
import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
)
// promptsDir is the tracker/prompt files directory; empty until InitPrompts runs.
var promptsDir string

// InitPrompts sets the directory where tracker files are located.
// This must be called by the main application at startup.
func InitPrompts(path string) {
	promptsDir = path
}

// TrackerPromptsDir returns the configured trackers directory.
// This is used by local tracker loading functions in consumer packages.
func TrackerPromptsDir() string {
	return promptsDir
}
// GeminiPart represents a single part in the Gemini content, can be text or inline_data.
// Exactly one of Text or InlineData should be set; omitempty keeps the unset
// field out of the serialized request.
type GeminiPart struct {
// Text holds a plain-text segment of the prompt.
Text string `json:"text,omitempty"`
// InlineData holds embedded media for multimodal input (see GeminiInlineData).
InlineData *GeminiInlineData `json:"inline_data,omitempty"`
}
// GeminiInlineData represents inline data for multimodal input.
type GeminiInlineData struct {
// MimeType identifies the media type of Data (e.g. an image MIME type).
MimeType string `json:"mime_type"`
// Data carries the media payload; presumably base64-encoded per the Gemini
// inline_data convention — TODO(review): confirm against callers.
Data string `json:"data"`
}
// GeminiConfig allows overriding default generation parameters.
// Fields are pointers so that nil distinguishes "not set" from an explicit
// zero value; nil fields fall back to the defaults in CallGeminiMultimodal.
type GeminiConfig struct {
// Temperature controls sampling randomness.
Temperature *float64 `json:"temperature,omitempty"`
// MaxOutputTokens caps the length of the generated response.
MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
// ResponseMimeType sets the desired response format (e.g. "application/json").
ResponseMimeType *string `json:"responseMimeType,omitempty"`
Model *string `json:"model,omitempty"` // Allows specifying a different model
}
// CallGemini sends a text-only request to the Gemini API with default configuration.
// It returns the raw text response from the model.
func CallGemini(prompt string) (string, error) {
parts := []GeminiPart{
{Text: prompt},
}
return CallGeminiMultimodal(parts, nil)
}
// CallGeminiMultimodal sends a request to the Gemini API with flexible content and configuration.
// parts may mix text and inline-data entries. config overrides the default generation
// parameters (model gemini-2.0-flash, temperature 0.1, 2048 max output tokens,
// "application/json" response MIME type); pass nil to use all defaults. The caller's
// config is never mutated. It returns the model's text response with any markdown
// code fences stripped and surrounding whitespace trimmed.
func CallGeminiMultimodal(parts []GeminiPart, config *GeminiConfig) (string, error) {
	if GeminiKey == "" {
		return "", fmt.Errorf("Gemini API key not configured")
	}
	// Default configuration values.
	defaultTemperature := 0.1
	defaultMaxOutputTokens := 2048
	defaultResponseMimeType := "application/json"
	defaultModel := "gemini-2.0-flash"
	// Apply defaults to a local copy so they are not written back into the
	// caller's struct (the previous version mutated *config as a side effect).
	var cfg GeminiConfig
	if config != nil {
		cfg = *config
	}
	if cfg.Temperature == nil {
		cfg.Temperature = &defaultTemperature
	}
	if cfg.MaxOutputTokens == nil {
		cfg.MaxOutputTokens = &defaultMaxOutputTokens
	}
	if cfg.ResponseMimeType == nil {
		cfg.ResponseMimeType = &defaultResponseMimeType
	}
	if cfg.Model == nil {
		cfg.Model = &defaultModel
	}
	reqBody := map[string]interface{}{
		"contents": []map[string]interface{}{
			{
				"parts": parts,
			},
		},
		// System instruction steers the model toward raw-JSON-only output.
		"systemInstruction": map[string]interface{}{
			"parts": []map[string]string{
				{"text": "You are a JSON-only API. Output raw JSON with no markdown, no code fences, no explanations. Start directly with { and end with }."},
			},
		},
		"generationConfig": map[string]interface{}{
			"temperature":      *cfg.Temperature,
			"maxOutputTokens":  *cfg.MaxOutputTokens,
			"responseMimeType": *cfg.ResponseMimeType,
		},
	}
	jsonBody, err := json.Marshal(reqBody)
	if err != nil {
		return "", fmt.Errorf("marshal request: %w", err)
	}
	url := fmt.Sprintf("https://generativelanguage.googleapis.com/v1beta/models/%s:generateContent?key=%s", *cfg.Model, GeminiKey)
	// NOTE(review): http.Post uses http.DefaultClient, which has no timeout, so a
	// hung connection blocks forever. Consider a shared client with a timeout.
	resp, err := http.Post(url, "application/json", bytes.NewReader(jsonBody))
	if err != nil {
		return "", fmt.Errorf("API request: %w", err)
	}
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("read response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("Gemini API error %d: %s", resp.StatusCode, string(body))
	}
	var geminiResp struct {
		Candidates []struct {
			Content struct {
				Parts []struct {
					Text string `json:"text"`
				} `json:"parts"`
			} `json:"content"`
		} `json:"candidates"`
	}
	if err := json.Unmarshal(body, &geminiResp); err != nil {
		return "", fmt.Errorf("parse Gemini response: %w", err)
	}
	if len(geminiResp.Candidates) == 0 || len(geminiResp.Candidates[0].Content.Parts) == 0 {
		return "", fmt.Errorf("empty response from Gemini")
	}
	finalText := strings.TrimSpace(geminiResp.Candidates[0].Content.Parts[0].Text)
	// The model sometimes still wraps the output in markdown, so we clean it.
	finalText = strings.TrimPrefix(finalText, "```json")
	finalText = strings.TrimPrefix(finalText, "```")
	finalText = strings.TrimSuffix(finalText, "```")
	// Trim again so fence removal never leaves stray newlines (consistent with
	// CallFireworks, which trims after stripping fences).
	return strings.TrimSpace(finalText), nil
}
// CallFireworks sends a request to the Fireworks AI API (OpenAI-compatible).
// messages should be OpenAI-format: []map[string]interface{} with "role" and "content" keys.
// For vision, content can be an array of {type: "text"/"image_url", ...} parts.
// Requests with maxTokens > 4096 are streamed via SSE and accumulated before
// returning. On 401/402/429 responses it also notifies via SendSignal.
// It returns the response text with markdown code fences stripped.
func CallFireworks(model string, messages []map[string]interface{}, maxTokens int) (string, error) {
	if FireworksKey == "" {
		return "", fmt.Errorf("Fireworks API key not configured")
	}
	// clean strips the markdown code fences some models wrap around JSON output.
	clean := func(s string) string {
		s = strings.TrimSpace(s)
		s = strings.TrimPrefix(s, "```json")
		s = strings.TrimPrefix(s, "```")
		s = strings.TrimSuffix(s, "```")
		return strings.TrimSpace(s)
	}
	// apiError formats a non-200 response and signals on auth/billing/rate-limit
	// statuses; previously this logic was duplicated in both response paths.
	apiError := func(status int, body []byte) error {
		msg := fmt.Sprintf("Fireworks API error %d: %s", status, string(body))
		if status == 401 || status == 402 || status == 429 {
			SendSignal("LLM: " + msg)
		}
		return fmt.Errorf("%s", msg)
	}
	// Large generations use SSE streaming — presumably to avoid server-side
	// response limits; confirm against the Fireworks docs.
	stream := maxTokens > 4096
	reqBody := map[string]interface{}{
		"model":       model,
		"messages":    messages,
		"max_tokens":  maxTokens,
		"temperature": 0.1,
		"stream":      stream,
	}
	jsonBody, err := json.Marshal(reqBody)
	if err != nil {
		return "", fmt.Errorf("marshal request: %w", err)
	}
	req, err := http.NewRequest("POST", "https://api.fireworks.ai/inference/v1/chat/completions", bytes.NewReader(jsonBody))
	if err != nil {
		return "", fmt.Errorf("create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+FireworksKey)
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("API request: %w", err)
	}
	defer resp.Body.Close()

	if !stream {
		// Non-streaming: a single JSON body holds the full completion.
		body, err := io.ReadAll(resp.Body)
		if err != nil {
			return "", fmt.Errorf("read response: %w", err)
		}
		if resp.StatusCode != http.StatusOK {
			return "", apiError(resp.StatusCode, body)
		}
		var oaiResp struct {
			Choices []struct {
				Message struct {
					Content string `json:"content"`
				} `json:"message"`
			} `json:"choices"`
		}
		if err := json.Unmarshal(body, &oaiResp); err != nil {
			return "", fmt.Errorf("parse response: %w", err)
		}
		if len(oaiResp.Choices) == 0 {
			return "", fmt.Errorf("empty response from Fireworks")
		}
		return clean(oaiResp.Choices[0].Message.Content), nil
	}

	// Streaming: read SSE chunks and accumulate the delta content.
	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body) // best-effort read of the error payload
		return "", apiError(resp.StatusCode, body)
	}
	var sb strings.Builder
	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 256*1024), 256*1024) // allow long SSE lines
	for scanner.Scan() {
		line := scanner.Text()
		if !strings.HasPrefix(line, "data: ") {
			continue
		}
		data := strings.TrimPrefix(line, "data: ")
		if data == "[DONE]" {
			break
		}
		var chunk struct {
			Choices []struct {
				Delta struct {
					Content string `json:"content"`
				} `json:"delta"`
			} `json:"choices"`
		}
		// Malformed chunks are skipped; only well-formed deltas accumulate.
		if json.Unmarshal([]byte(data), &chunk) == nil && len(chunk.Choices) > 0 {
			sb.WriteString(chunk.Choices[0].Delta.Content)
		}
	}
	// Fix: the original never checked scanner.Err(), so a stream broken mid-read
	// (or a line exceeding the buffer) silently returned truncated output.
	if err := scanner.Err(); err != nil {
		return "", fmt.Errorf("read stream: %w", err)
	}
	return clean(sb.String()), nil
}