250 lines
7.2 KiB
Go
250 lines
7.2 KiB
Go
package lib
|
|
|
|
import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"
)
|
|
|
|
// promptsDir holds the directory configured for tracker prompt files.
var promptsDir string

// InitPrompts records path as the directory containing tracker files.
// The main application must call this once at startup, before any
// consumer package attempts to load trackers.
func InitPrompts(path string) {
	promptsDir = path
}

// TrackerPromptsDir reports the trackers directory set by InitPrompts.
// Consumer packages use it when loading local tracker definitions.
func TrackerPromptsDir() string {
	return promptsDir
}
|
|
|
|
// GeminiPart represents a single part in the Gemini request content.
// A part carries either plain text or inline media; CallGemini sets
// only Text, while multimodal callers may populate InlineData instead.
type GeminiPart struct {
	// Text is the textual content of the part; omitted from JSON when empty.
	Text string `json:"text,omitempty"`
	// InlineData holds embedded media for multimodal input; omitted when nil.
	InlineData *GeminiInlineData `json:"inline_data,omitempty"`
}
|
|
|
|
// GeminiInlineData represents inline (embedded) data for multimodal input.
type GeminiInlineData struct {
	// MimeType identifies the media type of Data (e.g. "image/png").
	MimeType string `json:"mime_type"`
	// Data is the payload string — presumably base64-encoded per the
	// Gemini inline_data convention; confirm against callers.
	Data string `json:"data"`
}
|
|
|
|
// GeminiConfig allows overriding default generation parameters.
// Nil fields fall back to the defaults applied by CallGeminiMultimodal.
type GeminiConfig struct {
	// Temperature controls sampling randomness (default 0.1).
	Temperature *float64 `json:"temperature,omitempty"`
	// MaxOutputTokens caps the length of the generated reply (default 2048).
	MaxOutputTokens *int `json:"maxOutputTokens,omitempty"`
	// ResponseMimeType selects the response format (default "application/json").
	ResponseMimeType *string `json:"responseMimeType,omitempty"`
	// Model allows specifying a different model (default "gemini-2.0-flash").
	Model *string `json:"model,omitempty"`
}
|
|
|
|
// CallGemini sends a text-only request to the Gemini API with default configuration.
|
|
// It returns the raw text response from the model.
|
|
func CallGemini(prompt string) (string, error) {
|
|
parts := []GeminiPart{
|
|
{Text: prompt},
|
|
}
|
|
return CallGeminiMultimodal(parts, nil)
|
|
}
|
|
|
|
// CallGeminiMultimodal sends a request to the Gemini API with flexible content and configuration.
|
|
// It returns the raw text response from the model.
|
|
func CallGeminiMultimodal(parts []GeminiPart, config *GeminiConfig) (string, error) {
|
|
if GeminiKey == "" {
|
|
return "", fmt.Errorf("Gemini API key not configured")
|
|
}
|
|
|
|
// Default configuration
|
|
defaultTemperature := 0.1
|
|
defaultMaxOutputTokens := 2048
|
|
defaultResponseMimeType := "application/json"
|
|
defaultModel := "gemini-2.0-flash"
|
|
|
|
if config == nil {
|
|
config = &GeminiConfig{}
|
|
}
|
|
|
|
// Apply defaults if not overridden
|
|
if config.Temperature == nil {
|
|
config.Temperature = &defaultTemperature
|
|
}
|
|
if config.MaxOutputTokens == nil {
|
|
config.MaxOutputTokens = &defaultMaxOutputTokens
|
|
}
|
|
if config.ResponseMimeType == nil {
|
|
config.ResponseMimeType = &defaultResponseMimeType
|
|
}
|
|
if config.Model == nil {
|
|
config.Model = &defaultModel
|
|
}
|
|
|
|
reqBody := map[string]interface{}{
|
|
"contents": []map[string]interface{}{
|
|
{
|
|
"parts": parts,
|
|
},
|
|
},
|
|
"systemInstruction": map[string]interface{}{
|
|
"parts": []map[string]string{
|
|
{"text": "You are a JSON-only API. Output raw JSON with no markdown, no code fences, no explanations. Start directly with { and end with }."},
|
|
},
|
|
},
|
|
"generationConfig": map[string]interface{}{
|
|
"temperature": *config.Temperature,
|
|
"maxOutputTokens": *config.MaxOutputTokens,
|
|
"responseMimeType": *config.ResponseMimeType,
|
|
},
|
|
}
|
|
|
|
jsonBody, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return "", fmt.Errorf("marshal request: %w", err)
|
|
}
|
|
|
|
url := fmt.Sprintf("https://generativelanguage.googleapis.com/v1beta/models/%s:generateContent?key=%s", *config.Model, GeminiKey)
|
|
|
|
resp, err := http.Post(url, "application/json", bytes.NewReader(jsonBody))
|
|
if err != nil {
|
|
return "", fmt.Errorf("API request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return "", fmt.Errorf("read response: %w", err)
|
|
}
|
|
|
|
if resp.StatusCode != 200 {
|
|
return "", fmt.Errorf("Gemini API error %d: %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
var geminiResp struct {
|
|
Candidates []struct {
|
|
Content struct {
|
|
Parts []struct {
|
|
Text string `json:"text"`
|
|
} `json:"parts"`
|
|
} `json:"content"`
|
|
} `json:"candidates"`
|
|
}
|
|
if err := json.Unmarshal(body, &geminiResp); err != nil {
|
|
return "", fmt.Errorf("failed to parse Gemini response: %v", err)
|
|
}
|
|
|
|
if len(geminiResp.Candidates) == 0 || len(geminiResp.Candidates[0].Content.Parts) == 0 {
|
|
return "", fmt.Errorf("empty response from Gemini")
|
|
}
|
|
|
|
finalText := strings.TrimSpace(geminiResp.Candidates[0].Content.Parts[0].Text)
|
|
// The model sometimes still wraps the output in markdown, so we clean it.
|
|
finalText = strings.TrimPrefix(finalText, "```json")
|
|
finalText = strings.TrimPrefix(finalText, "```")
|
|
finalText = strings.TrimSuffix(finalText, "```")
|
|
|
|
return finalText, nil
|
|
}
|
|
|
|
// CallFireworks sends a request to the Fireworks AI API (OpenAI-compatible).
|
|
// messages should be OpenAI-format: []map[string]interface{} with "role" and "content" keys.
|
|
// For vision, content can be an array of {type: "text"/"image_url", ...} parts.
|
|
func CallFireworks(model string, messages []map[string]interface{}, maxTokens int) (string, error) {
|
|
if FireworksKey == "" {
|
|
return "", fmt.Errorf("Fireworks API key not configured")
|
|
}
|
|
|
|
stream := maxTokens > 4096
|
|
reqBody := map[string]interface{}{
|
|
"model": model,
|
|
"messages": messages,
|
|
"max_tokens": maxTokens,
|
|
"temperature": 0.1,
|
|
"stream": stream,
|
|
}
|
|
|
|
jsonBody, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return "", fmt.Errorf("marshal request: %w", err)
|
|
}
|
|
|
|
req, err := http.NewRequest("POST", "https://api.fireworks.ai/inference/v1/chat/completions", bytes.NewReader(jsonBody))
|
|
if err != nil {
|
|
return "", fmt.Errorf("create request: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", "Bearer "+FireworksKey)
|
|
|
|
resp, err := http.DefaultClient.Do(req)
|
|
if err != nil {
|
|
return "", fmt.Errorf("API request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if !stream {
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return "", fmt.Errorf("read response: %w", err)
|
|
}
|
|
if resp.StatusCode != 200 {
|
|
return "", fmt.Errorf("Fireworks API error %d: %s", resp.StatusCode, string(body))
|
|
}
|
|
var oaiResp struct {
|
|
Choices []struct {
|
|
Message struct {
|
|
Content string `json:"content"`
|
|
} `json:"message"`
|
|
} `json:"choices"`
|
|
}
|
|
if err := json.Unmarshal(body, &oaiResp); err != nil {
|
|
return "", fmt.Errorf("parse response: %w", err)
|
|
}
|
|
if len(oaiResp.Choices) == 0 {
|
|
return "", fmt.Errorf("empty response from Fireworks")
|
|
}
|
|
text := strings.TrimSpace(oaiResp.Choices[0].Message.Content)
|
|
text = strings.TrimPrefix(text, "```json")
|
|
text = strings.TrimPrefix(text, "```")
|
|
text = strings.TrimSuffix(text, "```")
|
|
return strings.TrimSpace(text), nil
|
|
}
|
|
|
|
// Streaming: read SSE chunks and accumulate content
|
|
if resp.StatusCode != 200 {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
return "", fmt.Errorf("Fireworks API error %d: %s", resp.StatusCode, string(body))
|
|
}
|
|
var sb strings.Builder
|
|
scanner := bufio.NewScanner(resp.Body)
|
|
scanner.Buffer(make([]byte, 256*1024), 256*1024)
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
if !strings.HasPrefix(line, "data: ") {
|
|
continue
|
|
}
|
|
data := line[6:]
|
|
if data == "[DONE]" {
|
|
break
|
|
}
|
|
var chunk struct {
|
|
Choices []struct {
|
|
Delta struct {
|
|
Content string `json:"content"`
|
|
} `json:"delta"`
|
|
} `json:"choices"`
|
|
}
|
|
if json.Unmarshal([]byte(data), &chunk) == nil && len(chunk.Choices) > 0 {
|
|
sb.WriteString(chunk.Choices[0].Delta.Content)
|
|
}
|
|
}
|
|
text := strings.TrimSpace(sb.String())
|
|
text = strings.TrimPrefix(text, "```json")
|
|
text = strings.TrimPrefix(text, "```")
|
|
text = strings.TrimSuffix(text, "```")
|
|
return strings.TrimSpace(text), nil
|
|
}
|