// pulse-monitor/gemini_ocr.go

package main
import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"image"
	"image/png"
	"io"
	"net/http"
	"os"
	"regexp"
	"strconv"
	"strings"
	"sync"
	"time"

	"gocv.io/x/gocv"
)
// Gemini OCR configuration.
//
// GeminiAPIKey is taken from the GEMINI_API_KEY environment variable when the
// program starts, so the credential never has to live in source control.
// NOTE(review): a key was previously hard-coded here; treat it as leaked and
// revoke it in the Google Cloud console.
//
// GeminiModel is the model name inserted into the generateContent endpoint.
var (
	GeminiAPIKey = os.Getenv("GEMINI_API_KEY")
	GeminiModel  = "gemini-2.0-flash-exp" // Cheapest vision model
)
// OCRCache remembers previous OCR results and looks them up by image
// similarity rather than by an exact hash, so a slightly noisy frame of the
// same reading can reuse an earlier answer. entries is guarded by mu.
type OCRCache struct {
	mu      sync.RWMutex
	entries []OCRCacheEntry
}

// OCRCacheEntry is a single remembered OCR result.
type OCRCacheEntry struct {
	// Image is the resized frame the value was read from. It is kept so new
	// frames can be compared against it; the cache owns and closes it.
	Image gocv.Mat
	// Value is the number reported for Image.
	Value int
	// Timestamp is when the entry was added.
	Timestamp time.Time
}

// ocrCache is the process-wide OCR cache; it starts out empty.
var ocrCache = &OCRCache{entries: make([]OCRCacheEntry, 0)}
// ResetOCRCache empties the OCR cache, closing every stored image, and then
// logs that the cache was cleared. Intended to be called at startup.
func ResetOCRCache() {
	clearEntries := func() {
		ocrCache.mu.Lock()
		defer ocrCache.mu.Unlock()
		// The cache owns these Mats, so their native memory is released here.
		for i := range ocrCache.entries {
			ocrCache.entries[i].Image.Close()
		}
		ocrCache.entries = []OCRCacheEntry{}
	}
	clearEntries()

	logMessage(Console, Info, "🗑️ OCR cache cleared")
}
// findBestCacheMatch returns the cached entry whose stored image is most
// similar to img, together with that similarity score.
//
// For every cached image, a margin of 20% of its width/height (at least 10px)
// is cropped from each side. The remaining center is used as a template and
// slid across img with normalized cross-correlation (gocv.TmCcoeffNormed).
// The best correlation found is that entry's score. Matching against a
// cropped template tolerates small position shifts between frames.
//
// It returns (nil, 0) when the cache is empty or when no entry scores above 0.
// Entries are skipped if they are empty, too small to crop, or produce a
// template that is not strictly smaller than img. The returned pointer refers
// to an element of the cache's internal slice; callers should only read it.
func findBestCacheMatch(img gocv.Mat) (*OCRCacheEntry, float32) {
	ocrCache.mu.RLock()
	defer ocrCache.mu.RUnlock()

	if len(ocrCache.entries) == 0 {
		return nil, 0
	}

	// Every MatchTemplate call shares one empty mask (meaning "no mask"), and
	// it is released on return. Creating gocv.NewMat() inline for each call
	// without closing it leaked a native Mat on every comparison.
	noMask := gocv.NewMat()
	defer noMask.Close()

	var bestEntry *OCRCacheEntry
	var bestScore float32
	bestIdx := -1
	for i := range ocrCache.entries {
		entry := &ocrCache.entries[i]
		score, ok := templateMatchScore(img, entry.Image, noMask)
		if ok && score > bestScore {
			bestScore = score
			bestEntry = entry
			bestIdx = i
		}
	}

	if DEBUG && bestEntry != nil {
		fmt.Printf(" Best match: cache[%d] value=%d similarity=%.4f\n", bestIdx, bestEntry.Value, bestScore)
	}
	return bestEntry, bestScore
}

// templateMatchScore crops the center of cached and returns the highest
// TmCcoeffNormed score of that template over img. ok is false, and no
// matching is done, in three cases:
//   - cached is empty;
//   - cropping would leave no center;
//   - the template is not strictly smaller than img in both dimensions.
func templateMatchScore(img, cached, mask gocv.Mat) (score float32, ok bool) {
	if cached.Empty() {
		return 0, false
	}

	// The margin is 20% per side, leaving the center ~60%, but never less than 10px.
	marginX := cached.Cols() / 5
	if marginX < 10 {
		marginX = 10
	}
	marginY := cached.Rows() / 5
	if marginY < 10 {
		marginY = 10
	}
	// Because of the 10px floor, very small images could yield an empty or
	// inverted rectangle. Skip them instead of building a bogus region.
	if cached.Cols() <= 2*marginX || cached.Rows() <= 2*marginY {
		return 0, false
	}

	template := cached.Region(image.Rect(marginX, marginY, cached.Cols()-marginX, cached.Rows()-marginY))
	defer template.Close()

	// The template must be strictly smaller than img so there is room to search positions.
	if template.Cols() >= img.Cols() || template.Rows() >= img.Rows() {
		return 0, false
	}

	result := gocv.NewMat()
	defer result.Close()
	gocv.MatchTemplate(img, template, &result, gocv.TmCcoeffNormed, mask)
	_, maxVal, _, _ := gocv.MinMaxLoc(result)
	return maxVal, true
}
// OCRCallCount counts Gemini API calls whose result was added to the cache.
// NOTE(review): this counter and the two below are updated without
// synchronization; confirm OCR is never invoked concurrently.
var OCRCallCount int

// OCRCacheHits counts OCR requests answered from the similarity cache.
var OCRCacheHits int

// OCRTotalTokens accumulates the totalTokenCount values reported by the API,
// used for cost tracking.
var OCRTotalTokens int
// apiCallTimes holds the times of recent Gemini API calls and is used by
// waitForRateLimit to stay under the per-minute quota.
var apiCallTimes []time.Time

// apiCallTimesMu guards apiCallTimes.
var apiCallTimesMu sync.Mutex
// waitForRateLimit blocks until another API call can be made without
// exceeding the 10-calls-per-minute quota, then records the call.
//
// It allows at most 8 calls in any trailing one-minute window, leaving a
// margin below 10. When the window is full, it sleeps until the oldest
// recorded call is a minute (plus a 1s buffer) old and re-checks. The mutex
// is released while sleeping.
func waitForRateLimit() {
	const (
		window   = time.Minute
		maxCalls = 8
		slack    = time.Second
	)

	apiCallTimesMu.Lock()
	defer apiCallTimesMu.Unlock()

	for {
		now := time.Now()
		cutoff := now.Add(-window)

		// Keep only the calls made within the current window.
		recent := make([]time.Time, 0, len(apiCallTimes))
		for _, ts := range apiCallTimes {
			if ts.After(cutoff) {
				recent = append(recent, ts)
			}
		}
		apiCallTimes = recent

		if len(apiCallTimes) < maxCalls {
			break
		}

		// The window is full: wait for the oldest call to age out.
		wait := apiCallTimes[0].Add(window + slack).Sub(now)
		if wait <= 0 {
			break
		}
		logMessage(Console, Info, "⏳ Rate limit: waiting %.0fs...", wait.Seconds())

		apiCallTimesMu.Unlock()
		time.Sleep(wait)
		apiCallTimesMu.Lock()
	}

	apiCallTimes = append(apiCallTimes, time.Now())
}
// resizeForOCR returns a new 140x85 copy of img, resized with linear
// interpolation, to keep API payloads small. 140x85 is half of 280x170 and is
// expected to stay readable for 2-3 digit numbers. The caller owns the
// returned Mat and must Close it.
func resizeForOCR(img gocv.Mat) gocv.Mat {
	const width, height = 140, 85
	out := gocv.NewMat()
	gocv.Resize(img, &out, image.Point{X: width, Y: height}, 0, 0, gocv.InterpolationLinear)
	return out
}
// matToBase64PNGForGemini encodes img as PNG and returns it as a base64
// string using the standard, padded alphabet. This is the form sent as
// inline_data to Gemini.
func matToBase64PNGForGemini(img gocv.Mat) (string, error) {
	goImg, err := img.ToImage()
	if err != nil {
		return "", fmt.Errorf("failed to convert Mat to image: %w", err)
	}

	var out bytes.Buffer
	b64 := base64.NewEncoder(base64.StdEncoding, &out)
	if err := png.Encode(b64, goImg); err != nil {
		return "", fmt.Errorf("failed to encode PNG: %w", err)
	}
	// Close flushes the final partial block and its padding. Writes to a
	// bytes.Buffer never fail, so the error is always nil.
	_ = b64.Close()

	return out.String(), nil
}
// GeminiOCR reads the number shown in img using the Gemini Vision API.
//
// img is downscaled with resizeForOCR and first compared against the
// similarity cache. If the best cached image scores at least 0.90, its value
// is returned without an API call. Otherwise:
//   - the resized image is PNG/base64 encoded;
//   - the call is throttled by waitForRateLimit and sent via callGeminiAPI;
//   - on success, the result is cached together with the resized image, and
//     the cache takes ownership of that Mat.
//
// img itself is neither modified nor closed. displayName labels the reading
// in the prompt, the logs, and the debug file names. On any error the
// returned value is -1.
//
// NOTE(review): each successful API call adds an entry, and only
// ResetOCRCache removes any, so the cache grows without bound over a long
// run. The per-lookup matching cost grows with it.
func GeminiOCR(img gocv.Mat, displayName string) (int, error) {
	// Minimum similarity for a cache hit. Per the original tuning notes, the
	// same digit with video noise scores ~93-94% and different digits ~6-9%.
	const cacheHitThreshold = 0.90

	resized := resizeForOCR(img)
	// resized is closed on every path except a successful API call, where
	// ownership is handed to the cache.
	ownedByCache := false
	defer func() {
		if !ownedByCache {
			resized.Close()
		}
	}()

	// In debug mode, save the image being compared so cache decisions can be inspected.
	if DEBUG {
		filename := fmt.Sprintf("test_output/ocr_%s_%d.png", displayName, time.Now().UnixMilli()%10000)
		gocv.IMWrite(filename, resized)
		logMessage(Console, Debug, " Saved %s (%dx%d)", filename, resized.Cols(), resized.Rows())
	}

	bestMatch, similarity := findBestCacheMatch(resized)
	if DEBUG {
		// Read the entry count under the lock, because entries may be
		// appended concurrently.
		ocrCache.mu.RLock()
		entryCount := len(ocrCache.entries)
		ocrCache.mu.RUnlock()
		logMessage(Console, Debug, " %s cache: entries=%d, bestSimilarity=%.4f", displayName, entryCount, similarity)
	}

	if bestMatch != nil && similarity >= cacheHitThreshold {
		OCRCacheHits++
		logMessage(Console, Debug, " %s: CACHE HIT (%.1f%%) → %d", displayName, similarity*100, bestMatch.Value)
		return bestMatch.Value, nil
	}

	if bestMatch != nil {
		logMessage(Console, Debug, " %s: CACHE MISS (%.1f%% < 90%%, bestMatch=%d) - calling API", displayName, similarity*100, bestMatch.Value)
	} else {
		logMessage(Console, Debug, " %s: CACHE EMPTY - calling API", displayName)
	}

	b64img, err := matToBase64PNGForGemini(resized)
	if err != nil {
		return -1, err
	}

	waitForRateLimit()

	number, err := callGeminiAPI(b64img, displayName)
	if err != nil {
		return -1, err
	}

	// Cache the result along with the image; from here on the cache owns resized.
	ocrCache.mu.Lock()
	ocrCache.entries = append(ocrCache.entries, OCRCacheEntry{
		Image:     resized,
		Value:     number,
		Timestamp: time.Now(),
	})
	cacheSize := len(ocrCache.entries)
	ocrCache.mu.Unlock()
	ownedByCache = true
	OCRCallCount++

	logMessage(Console, Debug, " %s: API CALL → %d (cache now has %d entries)", displayName, number, cacheSize)
	return number, nil
}
// callGeminiAPI calls Google's Gemini Vision API
func callGeminiAPI(b64img, displayName string) (int, error) {
url := fmt.Sprintf("https://generativelanguage.googleapis.com/v1beta/models/%s:generateContent?key=%s",
GeminiModel, GeminiAPIKey)
prompt := fmt.Sprintf("This image shows a %s reading from a pulse oximeter. What number is displayed? Reply with ONLY the number (2-3 digits). If unclear, reply FAIL.", displayName)
payload := map[string]interface{}{
"contents": []map[string]interface{}{
{
"parts": []map[string]interface{}{
{
"text": prompt,
},
{
"inline_data": map[string]string{
"mime_type": "image/png",
"data": b64img,
},
},
},
},
},
"generationConfig": map[string]interface{}{
"maxOutputTokens": 10,
"temperature": 0,
},
}
jsonData, err := json.Marshal(payload)
if err != nil {
return -1, err
}
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return -1, err
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return -1, err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return -1, err
}
if resp.StatusCode != 200 {
return -1, fmt.Errorf("Gemini API error %d: %s", resp.StatusCode, string(body))
}
// Parse response
var result struct {
Candidates []struct {
Content struct {
Parts []struct {
Text string `json:"text"`
} `json:"parts"`
} `json:"content"`
} `json:"candidates"`
UsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount"`
CandidatesTokenCount int `json:"candidatesTokenCount"`
TotalTokenCount int `json:"totalTokenCount"`
} `json:"usageMetadata"`
}
if err := json.Unmarshal(body, &result); err != nil {
return -1, fmt.Errorf("failed to parse response: %w", err)
}
if len(result.Candidates) == 0 || len(result.Candidates[0].Content.Parts) == 0 {
return -1, fmt.Errorf("no response from Gemini")
}
// Track token usage
OCRTotalTokens += result.UsageMetadata.TotalTokenCount
// Extract number from response
content := result.Candidates[0].Content.Parts[0].Text
return extractNumberFromResponse(content)
}
// numberPattern matches a run of decimal digits. It is compiled once at
// package scope instead of on every call.
var numberPattern = regexp.MustCompile(`\d+`)

// extractNumberFromResponse parses the first number from an LLM reply.
//
// Surrounding whitespace is trimmed first. Any occurrence of "FAIL"
// (case-insensitive) is treated as the model reporting an unreadable image,
// even if digits are also present. Otherwise the first run of digits is
// converted to an int. It returns -1 and an error for a FAIL reply, a reply
// with no digits, or a digit run too large for an int.
func extractNumberFromResponse(text string) (int, error) {
	text = strings.TrimSpace(text)

	if strings.Contains(strings.ToUpper(text), "FAIL") {
		return -1, fmt.Errorf("LLM returned FAIL (unclear image)")
	}

	first := numberPattern.FindString(text)
	if first == "" {
		return -1, fmt.Errorf("no number found in response: %s", text)
	}

	num, err := strconv.Atoi(first)
	if err != nil {
		return -1, err
	}
	return num, nil
}
// GetOCRStats returns the current OCR counters: the number of Gemini API
// calls, the number of cache hits, and the total tokens consumed.
func GetOCRStats() (calls int, cacheHits int, tokens int) {
	calls = OCRCallCount
	cacheHits = OCRCacheHits
	tokens = OCRTotalTokens
	return calls, cacheHits, tokens
}
// EstimateCost returns a rough USD cost estimate for all tokens used so far.
//
// Per the pricing noted by the author, Gemini Flash costs about $0.075 per 1M
// input tokens and $0.30 per 1M output tokens. Usage is dominated by image
// input, so a blended rate of about $0.10 per 1M tokens is applied to the
// total.
func EstimateCost() float64 {
	const usdPerToken = 0.10 / 1e6
	return float64(OCRTotalTokens) * usdPerToken
}