151 lines
4.3 KiB
Go
151 lines
4.3 KiB
Go
package lib
|
|
|
|
import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"
)
|
|
|
|
// promptsDir is the directory holding prompt files; set once via InitPrompts
// at application startup and read back through PromptsDir. Empty until configured.
var promptsDir string
|
|
|
|
// InitPrompts sets the directory where prompt files are located.
|
|
// This must be called by the main application at startup.
|
|
func InitPrompts(path string) {
|
|
promptsDir = path
|
|
}
|
|
|
|
// PromptsDir returns the configured prompts directory.
|
|
// This is used by local prompt loading functions in consumer packages.
|
|
func PromptsDir() string {
|
|
return promptsDir
|
|
}
|
|
|
|
// GeminiPart represents a single part in the Gemini content, can be text or inline_data.
// Exactly one of the two fields is expected to be set; the unset one is dropped
// from the JSON payload via omitempty.
type GeminiPart struct {
	// Text holds a plain-text segment of the prompt.
	Text string `json:"text,omitempty"`
	// InlineData carries non-text content for multimodal requests.
	InlineData *GeminiInlineData `json:"inline_data,omitempty"`
}
|
|
|
|
// GeminiInlineData represents inline data for multimodal input.
type GeminiInlineData struct {
	// MimeType identifies the media type of Data (e.g. "image/png").
	MimeType string `json:"mime_type"`
	// Data is the raw payload; presumably base64-encoded as the Gemini API
	// expects — NOTE(review): confirm at the call sites that populate it.
	Data string `json:"data"`
}
|
|
|
|
// GeminiConfig allows overriding default generation parameters.
// Fields are pointers so that "not set" is distinguishable from a zero value;
// nil fields fall back to the defaults applied in CallGeminiMultimodal.
type GeminiConfig struct {
	Temperature      *float64 `json:"temperature,omitempty"`
	MaxOutputTokens  *int     `json:"maxOutputTokens,omitempty"`
	ResponseMimeType *string  `json:"responseMimeType,omitempty"`
	Model            *string  `json:"model,omitempty"` // Allows specifying a different model
}
|
|
|
|
// CallGemini sends a text-only request to the Gemini API with default configuration.
|
|
// It returns the raw text response from the model.
|
|
func CallGemini(prompt string) (string, error) {
|
|
parts := []GeminiPart{
|
|
{Text: prompt},
|
|
}
|
|
return CallGeminiMultimodal(parts, nil)
|
|
}
|
|
|
|
// CallGeminiMultimodal sends a request to the Gemini API with flexible content and configuration.
|
|
// It returns the raw text response from the model.
|
|
func CallGeminiMultimodal(parts []GeminiPart, config *GeminiConfig) (string, error) {
|
|
if GeminiKey == "" {
|
|
return "", fmt.Errorf("Gemini API key not configured")
|
|
}
|
|
|
|
// Default configuration
|
|
defaultTemperature := 0.1
|
|
defaultMaxOutputTokens := 2048
|
|
defaultResponseMimeType := "application/json"
|
|
defaultModel := "gemini-2.0-flash"
|
|
|
|
if config == nil {
|
|
config = &GeminiConfig{}
|
|
}
|
|
|
|
// Apply defaults if not overridden
|
|
if config.Temperature == nil {
|
|
config.Temperature = &defaultTemperature
|
|
}
|
|
if config.MaxOutputTokens == nil {
|
|
config.MaxOutputTokens = &defaultMaxOutputTokens
|
|
}
|
|
if config.ResponseMimeType == nil {
|
|
config.ResponseMimeType = &defaultResponseMimeType
|
|
}
|
|
if config.Model == nil {
|
|
config.Model = &defaultModel
|
|
}
|
|
|
|
reqBody := map[string]interface{}{
|
|
"contents": []map[string]interface{}{
|
|
{
|
|
"parts": parts,
|
|
},
|
|
},
|
|
"systemInstruction": map[string]interface{}{
|
|
"parts": []map[string]string{
|
|
{"text": "You are a JSON-only API. Output raw JSON with no markdown, no code fences, no explanations. Start directly with { and end with }."},
|
|
},
|
|
},
|
|
"generationConfig": map[string]interface{}{
|
|
"temperature": *config.Temperature,
|
|
"maxOutputTokens": *config.MaxOutputTokens,
|
|
"responseMimeType": *config.ResponseMimeType,
|
|
},
|
|
}
|
|
|
|
jsonBody, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return "", fmt.Errorf("marshal request: %w", err)
|
|
}
|
|
|
|
url := fmt.Sprintf("https://generativelanguage.googleapis.com/v1beta/models/%s:generateContent?key=%s", *config.Model, GeminiKey)
|
|
|
|
resp, err := http.Post(url, "application/json", bytes.NewReader(jsonBody))
|
|
if err != nil {
|
|
return "", fmt.Errorf("API request: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return "", fmt.Errorf("read response: %w", err)
|
|
}
|
|
|
|
if resp.StatusCode != 200 {
|
|
return "", fmt.Errorf("Gemini API error %d: %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
var geminiResp struct {
|
|
Candidates []struct {
|
|
Content struct {
|
|
Parts []struct {
|
|
Text string `json:"text"`
|
|
} `json:"parts"`
|
|
} `json:"content"`
|
|
} `json:"candidates"`
|
|
}
|
|
if err := json.Unmarshal(body, &geminiResp); err != nil {
|
|
return "", fmt.Errorf("failed to parse Gemini response: %v", err)
|
|
}
|
|
|
|
if len(geminiResp.Candidates) == 0 || len(geminiResp.Candidates[0].Content.Parts) == 0 {
|
|
return "", fmt.Errorf("empty response from Gemini")
|
|
}
|
|
|
|
finalText := strings.TrimSpace(geminiResp.Candidates[0].Content.Parts[0].Text)
|
|
// The model sometimes still wraps the output in markdown, so we clean it.
|
|
finalText = strings.TrimPrefix(finalText, "```json")
|
|
finalText = strings.TrimPrefix(finalText, "```")
|
|
finalText = strings.TrimSuffix(finalText, "```")
|
|
|
|
return finalText, nil
|
|
}
|