257 lines
7.5 KiB
Go
257 lines
7.5 KiB
Go
package main
|
|
|
|
import (
	"encoding/json"
	"fmt"
	"log"
	"os"
	"path/filepath"
	"sort"
	"strings"
	"time"

	"inou/lib"
)
|
|
|
|
// --- Local Structs for Prompt Processing ---
|
|
// These are defined in api/llm_types.go and should be used from there.
|
|
// They are commented out here to prevent redeclaration.
|
|
|
|
/*
|
|
type TriageResponse struct { ... }
|
|
type ExtractionResult struct { ... }
|
|
type InputConfig struct { ... }
|
|
type FormGroup struct { ... }
|
|
type FormField struct { ... }
|
|
type ScheduleSlot struct { ... }
|
|
type EntryData struct { ... }
|
|
var ValidCategories = map[string]bool{ ... }
|
|
*/
|
|
|
|
|
|
// --- API-Specific Logic ---
|
|
|
|
func loadLLMConfig() {
|
|
// Load GeminiKey from file or environment
|
|
data, err := os.ReadFile("anthropic.env")
|
|
if err != nil {
|
|
log.Printf("Warning: anthropic.env not found. Looking for GEMINI_API_KEY in environment.")
|
|
}
|
|
for _, line := range strings.Split(string(data), "\n") {
|
|
parts := strings.SplitN(line, "=", 2)
|
|
if len(parts) == 2 && parts[0] == "GEMINI_API_KEY" {
|
|
lib.GeminiKey = strings.TrimSpace(parts[1])
|
|
}
|
|
}
|
|
if lib.GeminiKey == "" {
|
|
lib.GeminiKey = os.Getenv("GEMINI_API_KEY")
|
|
}
|
|
if lib.GeminiKey != "" {
|
|
log.Println("Gemini API key loaded.")
|
|
} else {
|
|
log.Println("Warning: Gemini API key not found.")
|
|
}
|
|
|
|
// Initialize trackers directory
|
|
exe, _ := os.Executable()
|
|
promptsDir := filepath.Join(filepath.Dir(exe), "..", "api", "trackers")
|
|
if _, err := os.Stat(promptsDir); os.IsNotExist(err) {
|
|
promptsDir = "trackers" // Dev fallback
|
|
}
|
|
lib.InitPrompts(promptsDir)
|
|
log.Printf("Prompts directory set to: %s", lib.TrackerPromptsDir())
|
|
}
|
|
|
|
// callLLMForTracker is the main entry point for turning user text into a structured prompt.
|
|
func callLLMForTracker(userInput string, dossierID string) (*ExtractionResult, error) {
|
|
triage, err := runTriage(userInput, dossierID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if triage.Error != "" {
|
|
return &ExtractionResult{Error: triage.Error}, nil
|
|
}
|
|
|
|
existingTypes := getExistingTrackerTypes(dossierID) // Assuming db is accessible in api/main
|
|
return runExtraction(userInput, triage.Category, triage.Language, dossierID, existingTypes)
|
|
}
|
|
|
|
|
|
// --- Local Prompt Handling & DB Functions ---
|
|
|
|
func loadPrompt(name string) (string, error) {
|
|
path := filepath.Join(lib.TrackerPromptsDir(), name+".md")
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return string(data), nil
|
|
}
|
|
|
|
func runTriage(userInput string, dossierID string) (*TriageResponse, error) {
|
|
tmpl, err := loadPrompt("triage")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to load triage prompt: %v", err)
|
|
}
|
|
|
|
// Get dossier context
|
|
dossier, _ := lib.DossierGet(nil, dossierID)
|
|
dossierName := "Patient"
|
|
dossierDOB := ""
|
|
if dossier != nil && dossier.Name != "" {
|
|
dossierName = dossier.Name
|
|
}
|
|
if dossier != nil && dossier.DateOfBirth != "" {
|
|
dossierDOB = dossier.DateOfBirth
|
|
}
|
|
|
|
prompt := strings.ReplaceAll(tmpl, "{{INPUT}}", userInput)
|
|
prompt = strings.ReplaceAll(prompt, "{{DOSSIER_NAME}}", dossierName)
|
|
prompt = strings.ReplaceAll(prompt, "{{DOSSIER_DOB}}", dossierDOB)
|
|
|
|
respText, err := lib.CallGemini(prompt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var result TriageResponse
|
|
if err := json.Unmarshal([]byte(respText), &result); err != nil {
|
|
var errMap map[string]string
|
|
if json.Unmarshal([]byte(respText), &errMap) == nil {
|
|
if errMsg, ok := errMap["error"]; ok {
|
|
result.Error = errMsg
|
|
return &result, nil
|
|
}
|
|
}
|
|
return nil, fmt.Errorf("failed to parse triage JSON: %v (raw: %s)", err, respText)
|
|
}
|
|
|
|
if _, ok := ValidCategories[result.Category]; !ok && result.Error == "" {
|
|
result.Category = "note"
|
|
}
|
|
return &result, nil
|
|
}
|
|
|
|
func runExtraction(userInput, category, language, dossierID string, existingTypes map[string][]string) (*ExtractionResult, error) {
|
|
tmpl, err := loadPrompt(category)
|
|
if err != nil {
|
|
tmpl, err = loadPrompt("default")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to load prompt: %v", err)
|
|
}
|
|
}
|
|
|
|
// Get dossier name and DOB
|
|
dossier, _ := lib.DossierGet(nil, dossierID)
|
|
dossierName := "Patient"
|
|
dossierDOB := ""
|
|
if dossier != nil && dossier.Name != "" {
|
|
dossierName = dossier.Name
|
|
}
|
|
if dossier != nil && dossier.DateOfBirth != "" {
|
|
dossierDOB = dossier.DateOfBirth
|
|
}
|
|
|
|
// Get current date and year
|
|
now := time.Now()
|
|
currentDate := now.Format("2006-01-02")
|
|
currentYear := now.Format("2006")
|
|
|
|
var existingStr string
|
|
for cat, types := range existingTypes {
|
|
if len(types) > 0 {
|
|
existingStr += fmt.Sprintf("- %s: %v\n", cat, types)
|
|
}
|
|
}
|
|
if existingStr == "" {
|
|
existingStr = "(none yet)"
|
|
}
|
|
|
|
prompt := tmpl
|
|
prompt = strings.ReplaceAll(prompt, "{{INPUT}}", userInput)
|
|
prompt = strings.ReplaceAll(prompt, "{{LANGUAGE}}", language)
|
|
prompt = strings.ReplaceAll(prompt, "{{CATEGORY}}", category)
|
|
prompt = strings.ReplaceAll(prompt, "{{EXISTING_TYPES}}", existingStr)
|
|
prompt = strings.ReplaceAll(prompt, "{{DOSSIER_NAME}}", dossierName)
|
|
prompt = strings.ReplaceAll(prompt, "{{DOSSIER_DOB}}", dossierDOB)
|
|
prompt = strings.ReplaceAll(prompt, "{{CURRENT_DATE}}", currentDate)
|
|
prompt = strings.ReplaceAll(prompt, "{{CURRENT_YEAR}}", currentYear)
|
|
|
|
respText, err := lib.CallGemini(prompt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
log.Printf("Gemini raw response for %s: %s", category, respText)
|
|
|
|
// First try plural "entries" format (used by exercise.md and others)
|
|
var result ExtractionResult
|
|
if err := json.Unmarshal([]byte(respText), &result); err == nil {
|
|
log.Printf("Parsed as plural format: entries=%d", len(result.Entries))
|
|
if result.Category == "" {
|
|
result.Category = category
|
|
}
|
|
// If we got entries in plural format, return it
|
|
if len(result.Entries) > 0 || result.Error != "" {
|
|
return &result, nil
|
|
}
|
|
} else {
|
|
log.Printf("Plural parse failed: %v", err)
|
|
}
|
|
|
|
// Fallback: try singular "entry" format
|
|
var singleEntryResult struct {
|
|
Question string `json:"question"`
|
|
Category string `json:"category"`
|
|
Type string `json:"type"`
|
|
InputType string `json:"input_type"`
|
|
InputConfig InputConfig `json:"input_config"`
|
|
Schedule []ScheduleSlot `json:"schedule"`
|
|
Entry *EntryData `json:"entry,omitempty"`
|
|
Error string `json:"error,omitempty"`
|
|
}
|
|
if err := json.Unmarshal([]byte(respText), &singleEntryResult); err == nil {
|
|
log.Printf("Parsed as singular format: entry=%v", singleEntryResult.Entry != nil)
|
|
result := ExtractionResult{
|
|
Question: singleEntryResult.Question,
|
|
Category: singleEntryResult.Category,
|
|
Type: singleEntryResult.Type,
|
|
InputType: singleEntryResult.InputType,
|
|
InputConfig: singleEntryResult.InputConfig,
|
|
Schedule: singleEntryResult.Schedule,
|
|
Error: singleEntryResult.Error,
|
|
}
|
|
if singleEntryResult.Entry != nil {
|
|
result.Entries = []*EntryData{singleEntryResult.Entry}
|
|
}
|
|
if result.Category == "" {
|
|
result.Category = category
|
|
}
|
|
return &result, nil
|
|
}
|
|
|
|
return nil, fmt.Errorf("failed to parse extraction JSON in either format (raw: %s)", respText)
|
|
}
|
|
|
|
|
|
func getExistingTrackerTypes(dossierID string) map[string][]string {
|
|
result, err := lib.TrackerDistinctTypes(dossierID)
|
|
if err != nil {
|
|
log.Printf("Failed to get existing tracker types: %v", err)
|
|
return make(map[string][]string)
|
|
}
|
|
return result
|
|
}
|
|
|
|
// --- Deprecated Anthropic/Sonnet Functions ---
|
|
// Kept for reference, but no longer used in the main flow.
|
|
|
|
// anthropicKey is left over from the Anthropic integration. Nothing in this
// file reads or assigns it; the key loaded at startup is now lib.GeminiKey.
//
// Deprecated: retained only alongside the legacy Sonnet helpers.
var anthropicKey = ""
|
|
|
|
func callSonnet(prompt string) (string, error) {
|
|
return callSonnetWithRetry(prompt, 5, 15*time.Second)
|
|
}
|
|
|
|
// callSonnetWithRetry is a stub left from the removed Sonnet integration.
// Its arguments are ignored; it always returns an empty string and an error.
//
// Deprecated: the tracker generation logic no longer calls Sonnet.
func callSonnetWithRetry(prompt string, maxRetries int, baseDelay time.Duration) (string, error) {
	errDeprecated := fmt.Errorf("callSonnet is deprecated")
	return "", errDeprecated
}