inou/lib/llm.go

151 lines
4.3 KiB
Go

package lib
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
)
// promptsDir holds the directory path most recently handed to InitPrompts.
// It is the empty string until InitPrompts has been called.
var promptsDir string

// InitPrompts records path as the directory in which tracker files live.
//
// The main application is expected to call this during startup, before any
// code asks for the directory through TrackerPromptsDir.
func InitPrompts(path string) {
	promptsDir = path
}
// TrackerPromptsDir reports the trackers directory that was configured through
// InitPrompts. The result is the empty string if no directory has been set.
//
// Consumer packages use it to locate tracker files they load themselves.
func TrackerPromptsDir() string {
	return promptsDir
}
// GeminiPart is one element of the "parts" array in a Gemini content entry.
// A part carries either plain text or inline binary data; whichever field is
// left at its zero value is omitted from the encoded JSON.
type GeminiPart struct {
	// Text is the textual payload of the part.
	Text string `json:"text,omitempty"`
	// InlineData is the non-text payload of the part, or nil for text parts.
	InlineData *GeminiInlineData `json:"inline_data,omitempty"`
}
// GeminiInlineData is the inline (non-text) payload of a GeminiPart, used for
// multimodal input. Both fields are always present in the encoded JSON.
type GeminiInlineData struct {
	// MimeType names the media type of Data.
	MimeType string `json:"mime_type"`
	// Data is the payload itself, carried as a string.
	// NOTE(review): presumably base64-encoded by the caller — confirm.
	Data string `json:"data"`
}
// GeminiConfig carries optional per-call overrides for a Gemini request.
// Every field is a pointer so that "not set" (nil) can be told apart from an
// explicit zero value; CallGeminiMultimodal substitutes its own default for
// each nil field.
type GeminiConfig struct {
	// Temperature overrides the sampling temperature.
	Temperature *float64 `json:"temperature,omitempty"`
	// MaxOutputTokens overrides the cap on generated tokens.
	MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
	// ResponseMimeType overrides the response MIME type requested from the model.
	ResponseMimeType *string `json:"responseMimeType,omitempty"`
	// Model overrides which model the request is addressed to.
	Model *string `json:"model,omitempty"`
}
// CallGemini sends a text-only request to the Gemini API with default configuration.
// It returns the raw text response from the model.
func CallGemini(prompt string) (string, error) {
parts := []GeminiPart{
{Text: prompt},
}
return CallGeminiMultimodal(parts, nil)
}
// CallGeminiMultimodal sends parts to the Gemini generateContent endpoint as a
// single content entry and returns the text of the first part of the first
// candidate in the reply.
//
// config may be nil. Each nil field falls back to a default: temperature 0.1,
// 2048 output tokens, response MIME type "application/json" and model
// "gemini-2.0-flash". config is only read; it is never modified.
//
// Every request carries a fixed system instruction asking the model for raw
// JSON output, independent of config.
//
// The returned text is whitespace-trimmed, and a leading "```json" or "```"
// and a trailing "```" are stripped, because the model sometimes wraps its
// answer in a Markdown fence anyway.
//
// An error is returned when GeminiKey is empty, when the request cannot be
// encoded, built or sent, when the response body cannot be read, when the API
// answers with a status other than 200 (the response body is included in the
// error), when the response is not valid JSON, or when it holds no candidate
// parts.
func CallGeminiMultimodal(parts []GeminiPart, config *GeminiConfig) (string, error) {
	if GeminiKey == "" {
		return "", fmt.Errorf("Gemini API key not configured")
	}

	// Resolve the effective settings into locals so that defaults are never
	// written back into the caller's config struct.
	temperature := 0.1
	maxOutputTokens := 2048
	responseMimeType := "application/json"
	model := "gemini-2.0-flash"
	if config != nil {
		if config.Temperature != nil {
			temperature = *config.Temperature
		}
		if config.MaxOutputTokens != nil {
			maxOutputTokens = *config.MaxOutputTokens
		}
		if config.ResponseMimeType != nil {
			responseMimeType = *config.ResponseMimeType
		}
		if config.Model != nil {
			model = *config.Model
		}
	}

	reqBody := map[string]interface{}{
		"contents": []map[string]interface{}{
			{
				"parts": parts,
			},
		},
		"systemInstruction": map[string]interface{}{
			"parts": []map[string]string{
				{"text": "You are a JSON-only API. Output raw JSON with no markdown, no code fences, no explanations. Start directly with { and end with }."},
			},
		},
		"generationConfig": map[string]interface{}{
			"temperature":      temperature,
			"maxOutputTokens":  maxOutputTokens,
			"responseMimeType": responseMimeType,
		},
	}
	jsonBody, err := json.Marshal(reqBody)
	if err != nil {
		return "", fmt.Errorf("marshal request: %w", err)
	}

	// The key is sent in the x-goog-api-key header instead of the query
	// string: client errors embed the request URL, so a key placed in the URL
	// would leak into every wrapped error message and log line.
	endpoint := fmt.Sprintf("https://generativelanguage.googleapis.com/v1beta/models/%s:generateContent", model)
	req, err := http.NewRequest(http.MethodPost, endpoint, bytes.NewReader(jsonBody))
	if err != nil {
		return "", fmt.Errorf("build request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("x-goog-api-key", GeminiKey)

	// NOTE(review): http.DefaultClient has no timeout, so a stalled connection
	// blocks this call indefinitely; consider a shared client with a Timeout.
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("API request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("read response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("Gemini API error %d: %s", resp.StatusCode, string(body))
	}

	// Only the fields needed to reach the first text part are decoded.
	var geminiResp struct {
		Candidates []struct {
			Content struct {
				Parts []struct {
					Text string `json:"text"`
				} `json:"parts"`
			} `json:"content"`
		} `json:"candidates"`
	}
	if err := json.Unmarshal(body, &geminiResp); err != nil {
		return "", fmt.Errorf("failed to parse Gemini response: %w", err)
	}
	if len(geminiResp.Candidates) == 0 || len(geminiResp.Candidates[0].Content.Parts) == 0 {
		return "", fmt.Errorf("empty response from Gemini")
	}

	// The model sometimes still wraps the output in markdown, so we clean it.
	finalText := strings.TrimSpace(geminiResp.Candidates[0].Content.Parts[0].Text)
	finalText = strings.TrimPrefix(finalText, "```json")
	finalText = strings.TrimPrefix(finalText, "```")
	finalText = strings.TrimSuffix(finalText, "```")
	// Trim again: removing the fence exposes the newlines that sat inside it.
	return strings.TrimSpace(finalText), nil
}