388 lines
10 KiB
Go
388 lines
10 KiB
Go
package main
|
|
|
|
import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"image"
	"image/png"
	"io"
	"net/http"
	"os"
	"regexp"
	"strconv"
	"strings"
	"sync"
	"time"

	"gocv.io/x/gocv"
)
|
|
|
|
// Gemini OCR configuration.
//
// SECURITY: API keys must never be committed to source control. The key that
// was previously hard-coded here has been exposed and must be revoked and
// rotated. The key is now read from the GEMINI_API_KEY environment variable
// when the package is initialised. If the variable is unset, GeminiAPIKey is
// empty and Gemini requests are expected to fail authentication.
var (
	// GeminiAPIKey authenticates requests to the Gemini API.
	GeminiAPIKey = os.Getenv("GEMINI_API_KEY")

	// GeminiModel is the Gemini model used for OCR requests.
	GeminiModel = "gemini-2.0-flash-exp"
)
|
|
|
|
// OCR cache using image similarity (not hash!)
|
|
type OCRCache struct {
|
|
mu sync.RWMutex
|
|
entries []OCRCacheEntry
|
|
}
|
|
|
|
type OCRCacheEntry struct {
|
|
Image gocv.Mat // Store the actual image for comparison
|
|
Value int
|
|
Timestamp time.Time
|
|
}
|
|
|
|
var ocrCache = &OCRCache{
|
|
entries: []OCRCacheEntry{},
|
|
}
|
|
|
|
// ResetCache clears the OCR cache (call at startup)
|
|
func ResetOCRCache() {
|
|
ocrCache.mu.Lock()
|
|
// Close all cached images
|
|
for _, entry := range ocrCache.entries {
|
|
entry.Image.Close()
|
|
}
|
|
ocrCache.entries = []OCRCacheEntry{}
|
|
ocrCache.mu.Unlock()
|
|
logMessage(Console, Info, "🗑️ OCR cache cleared")
|
|
}
|
|
|
|
// findBestCacheMatch finds the cached entry with highest similarity to the input image
|
|
// Uses template matching with smaller template to handle position shifts
|
|
// Returns the entry and similarity score (0-1), or nil if cache is empty
|
|
func findBestCacheMatch(img gocv.Mat) (*OCRCacheEntry, float32) {
|
|
ocrCache.mu.RLock()
|
|
defer ocrCache.mu.RUnlock()
|
|
|
|
if len(ocrCache.entries) == 0 {
|
|
return nil, 0
|
|
}
|
|
|
|
var bestEntry *OCRCacheEntry
|
|
var bestScore float32 = 0
|
|
var bestIdx int = -1
|
|
|
|
for i := range ocrCache.entries {
|
|
entry := &ocrCache.entries[i]
|
|
if entry.Image.Empty() {
|
|
continue
|
|
}
|
|
|
|
// Create a smaller version of cached image (crop center 60%)
|
|
// This allows template matching to handle larger position shifts
|
|
marginX := entry.Image.Cols() / 5 // 20% margin each side
|
|
marginY := entry.Image.Rows() / 5
|
|
if marginX < 10 {
|
|
marginX = 10
|
|
}
|
|
if marginY < 10 {
|
|
marginY = 10
|
|
}
|
|
|
|
// Crop center portion as template
|
|
templateRect := image.Rect(marginX, marginY,
|
|
entry.Image.Cols()-marginX, entry.Image.Rows()-marginY)
|
|
template := entry.Image.Region(templateRect)
|
|
|
|
// Template must be smaller than image for proper matching
|
|
if template.Cols() >= img.Cols() || template.Rows() >= img.Rows() {
|
|
template.Close()
|
|
continue
|
|
}
|
|
|
|
// Compare using template matching - find best position
|
|
result := gocv.NewMat()
|
|
gocv.MatchTemplate(img, template, &result, gocv.TmCcoeffNormed, gocv.NewMat())
|
|
_, maxVal, _, _ := gocv.MinMaxLoc(result)
|
|
result.Close()
|
|
template.Close()
|
|
|
|
if maxVal > bestScore {
|
|
bestScore = maxVal
|
|
bestEntry = entry
|
|
bestIdx = i
|
|
}
|
|
}
|
|
|
|
if DEBUG && bestEntry != nil {
|
|
fmt.Printf(" Best match: cache[%d] value=%d similarity=%.4f\n", bestIdx, bestEntry.Value, bestScore)
|
|
}
|
|
|
|
return bestEntry, bestScore
|
|
}
|
|
|
|
// OCR counters used for cost tracking (read by GetOCRStats and EstimateCost).
//
// NOTE(review): these are plain ints that are updated without any
// synchronisation. If OCR runs on several goroutines, the updates race.
var (
	// OCRCallCount counts Gemini calls whose result was parsed and cached.
	OCRCallCount int
	// OCRCacheHits counts OCR requests answered from the similarity cache.
	OCRCacheHits int
	// OCRTotalTokens accumulates the total token counts reported by Gemini.
	OCRTotalTokens int
)
|
|
|
|
// Rate limiting: track API call times to stay under 10/minute
|
|
var (
|
|
apiCallTimes []time.Time
|
|
apiCallTimesMu sync.Mutex
|
|
)
|
|
|
|
// waitForRateLimit ensures we don't exceed 10 API calls per minute
|
|
// Uses a conservative limit of 8 to provide buffer for timing variations
|
|
func waitForRateLimit() {
|
|
apiCallTimesMu.Lock()
|
|
|
|
for {
|
|
now := time.Now()
|
|
oneMinuteAgo := now.Add(-time.Minute)
|
|
|
|
// Remove calls older than 1 minute
|
|
validCalls := []time.Time{}
|
|
for _, t := range apiCallTimes {
|
|
if t.After(oneMinuteAgo) {
|
|
validCalls = append(validCalls, t)
|
|
}
|
|
}
|
|
apiCallTimes = validCalls
|
|
|
|
// If under 8 calls, we're good to proceed
|
|
if len(apiCallTimes) < 8 {
|
|
break
|
|
}
|
|
|
|
// Need to wait - calculate how long until oldest call expires
|
|
oldestCall := apiCallTimes[0]
|
|
waitUntil := oldestCall.Add(time.Minute).Add(time.Second) // +1s buffer
|
|
waitDuration := waitUntil.Sub(now)
|
|
if waitDuration <= 0 {
|
|
break
|
|
}
|
|
|
|
logMessage(Console, Info, "⏳ Rate limit: waiting %.0fs...", waitDuration.Seconds())
|
|
apiCallTimesMu.Unlock()
|
|
time.Sleep(waitDuration)
|
|
apiCallTimesMu.Lock()
|
|
// Loop back to re-check after waking
|
|
}
|
|
|
|
// Record this call
|
|
apiCallTimes = append(apiCallTimes, time.Now())
|
|
apiCallTimesMu.Unlock()
|
|
}
|
|
|
|
// resizeForOCR resizes image to small size for cheap API calls
|
|
// 140x85 is half of 280x170, should be readable for 2-3 digit numbers
|
|
func resizeForOCR(img gocv.Mat) gocv.Mat {
|
|
resized := gocv.NewMat()
|
|
gocv.Resize(img, &resized, image.Pt(140, 85), 0, 0, gocv.InterpolationLinear)
|
|
return resized
|
|
}
|
|
|
|
// matToBase64PNG converts a gocv.Mat to base64-encoded PNG
|
|
func matToBase64PNGForGemini(img gocv.Mat) (string, error) {
|
|
goImg, err := img.ToImage()
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to convert Mat to image: %w", err)
|
|
}
|
|
|
|
var buf bytes.Buffer
|
|
if err := png.Encode(&buf, goImg); err != nil {
|
|
return "", fmt.Errorf("failed to encode PNG: %w", err)
|
|
}
|
|
|
|
return base64.StdEncoding.EncodeToString(buf.Bytes()), nil
|
|
}
|
|
|
|
// GeminiOCR reads a number from an image using Gemini Vision API
|
|
// Returns the number and any error
|
|
func GeminiOCR(img gocv.Mat, displayName string) (int, error) {
|
|
// Resize for cheaper API calls
|
|
resized := resizeForOCR(img)
|
|
|
|
// DEBUG: Save images to see what we're comparing
|
|
if DEBUG {
|
|
filename := fmt.Sprintf("test_output/ocr_%s_%d.png", displayName, time.Now().UnixMilli()%10000)
|
|
gocv.IMWrite(filename, resized)
|
|
logMessage(Console, Debug, " Saved %s (%dx%d)", filename, resized.Cols(), resized.Rows())
|
|
}
|
|
|
|
// Check cache using image similarity
|
|
// Threshold 0.90: same digit with video noise ~93-94%, different digits ~6-9%
|
|
bestMatch, similarity := findBestCacheMatch(resized)
|
|
if DEBUG {
|
|
logMessage(Console, Debug, " %s cache: entries=%d, bestSimilarity=%.4f", displayName, len(ocrCache.entries), similarity)
|
|
}
|
|
if bestMatch != nil && similarity >= 0.90 {
|
|
resized.Close()
|
|
OCRCacheHits++
|
|
logMessage(Console, Debug, " %s: CACHE HIT (%.1f%%) → %d", displayName, similarity*100, bestMatch.Value)
|
|
return bestMatch.Value, nil
|
|
}
|
|
|
|
// Log cache miss reason
|
|
if bestMatch != nil {
|
|
logMessage(Console, Debug, " %s: CACHE MISS (%.1f%% < 90%%, bestMatch=%d) - calling API", displayName, similarity*100, bestMatch.Value)
|
|
} else {
|
|
logMessage(Console, Debug, " %s: CACHE EMPTY - calling API", displayName)
|
|
}
|
|
|
|
// Convert to base64
|
|
b64img, err := matToBase64PNGForGemini(resized)
|
|
if err != nil {
|
|
resized.Close()
|
|
return -1, err
|
|
}
|
|
|
|
// Rate limit before API call
|
|
waitForRateLimit()
|
|
|
|
// Call Gemini API
|
|
number, err := callGeminiAPI(b64img, displayName)
|
|
if err != nil {
|
|
resized.Close()
|
|
return -1, err
|
|
}
|
|
|
|
logMessage(Console, Debug, " %s: API CALL → %d (cache now has %d entries)", displayName, number, len(ocrCache.entries)+1)
|
|
|
|
// Cache the result with the actual image (don't close resized - it's now owned by cache)
|
|
ocrCache.mu.Lock()
|
|
ocrCache.entries = append(ocrCache.entries, OCRCacheEntry{
|
|
Image: resized, // Transfer ownership to cache
|
|
Value: number,
|
|
Timestamp: time.Now(),
|
|
})
|
|
ocrCache.mu.Unlock()
|
|
|
|
OCRCallCount++
|
|
return number, nil
|
|
}
|
|
|
|
// callGeminiAPI calls Google's Gemini Vision API
|
|
func callGeminiAPI(b64img, displayName string) (int, error) {
|
|
url := fmt.Sprintf("https://generativelanguage.googleapis.com/v1beta/models/%s:generateContent?key=%s",
|
|
GeminiModel, GeminiAPIKey)
|
|
|
|
prompt := fmt.Sprintf("This image shows a %s reading from a pulse oximeter. What number is displayed? Reply with ONLY the number (2-3 digits). If unclear, reply FAIL.", displayName)
|
|
|
|
payload := map[string]interface{}{
|
|
"contents": []map[string]interface{}{
|
|
{
|
|
"parts": []map[string]interface{}{
|
|
{
|
|
"text": prompt,
|
|
},
|
|
{
|
|
"inline_data": map[string]string{
|
|
"mime_type": "image/png",
|
|
"data": b64img,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
"generationConfig": map[string]interface{}{
|
|
"maxOutputTokens": 10,
|
|
"temperature": 0,
|
|
},
|
|
}
|
|
|
|
jsonData, err := json.Marshal(payload)
|
|
if err != nil {
|
|
return -1, err
|
|
}
|
|
|
|
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
|
if err != nil {
|
|
return -1, err
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
client := &http.Client{Timeout: 10 * time.Second}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return -1, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return -1, err
|
|
}
|
|
|
|
if resp.StatusCode != 200 {
|
|
return -1, fmt.Errorf("Gemini API error %d: %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
// Parse response
|
|
var result struct {
|
|
Candidates []struct {
|
|
Content struct {
|
|
Parts []struct {
|
|
Text string `json:"text"`
|
|
} `json:"parts"`
|
|
} `json:"content"`
|
|
} `json:"candidates"`
|
|
UsageMetadata struct {
|
|
PromptTokenCount int `json:"promptTokenCount"`
|
|
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
|
TotalTokenCount int `json:"totalTokenCount"`
|
|
} `json:"usageMetadata"`
|
|
}
|
|
|
|
if err := json.Unmarshal(body, &result); err != nil {
|
|
return -1, fmt.Errorf("failed to parse response: %w", err)
|
|
}
|
|
|
|
if len(result.Candidates) == 0 || len(result.Candidates[0].Content.Parts) == 0 {
|
|
return -1, fmt.Errorf("no response from Gemini")
|
|
}
|
|
|
|
// Track token usage
|
|
OCRTotalTokens += result.UsageMetadata.TotalTokenCount
|
|
|
|
// Extract number from response
|
|
content := result.Candidates[0].Content.Parts[0].Text
|
|
return extractNumberFromResponse(content)
|
|
}
|
|
|
|
// Number patterns used by extractNumberFromResponse. They are compiled once at
// package level instead of on every response.
var (
	// standaloneNumberRE captures a digit run that does not directly follow a
	// letter or digit. Identifiers such as "SpO2" are therefore not mistaken
	// for the reading.
	standaloneNumberRE = regexp.MustCompile(`(?:^|[^\p{L}\d])(\d+)`)
	// anyNumberRE matches any digit run; used only when no standalone number exists.
	anyNumberRE = regexp.MustCompile(`\d+`)
)

// extractNumberFromResponse parses the integer reading from an LLM reply.
//
// The reply is trimmed first. If it contains "FAIL" (any case), an error is
// returned, since the model was told to answer FAIL when the image is unclear.
// Otherwise the first standalone number is returned, e.g. "SpO2: 96%" gives 96.
// If there is no standalone number, the first digit run anywhere is used,
// e.g. "abc123" gives 123.
//
// It returns -1 and an error when there is no digit run at all, or when the
// digits do not fit in an int.
func extractNumberFromResponse(text string) (int, error) {
	text = strings.TrimSpace(text)

	if strings.Contains(strings.ToUpper(text), "FAIL") {
		return -1, fmt.Errorf("LLM returned FAIL (unclear image)")
	}

	var digits string
	if m := standaloneNumberRE.FindStringSubmatch(text); m != nil {
		digits = m[1]
	} else {
		digits = anyNumberRE.FindString(text)
	}
	if digits == "" {
		return -1, fmt.Errorf("no number found in response: %s", text)
	}

	num, err := strconv.Atoi(digits)
	if err != nil {
		return -1, err
	}
	return num, nil
}
|
|
|
|
// GetOCRStats returns current OCR statistics
|
|
func GetOCRStats() (calls int, cacheHits int, tokens int) {
|
|
return OCRCallCount, OCRCacheHits, OCRTotalTokens
|
|
}
|
|
|
|
// EstimateCost estimates the API cost based on token usage
|
|
// Gemini Flash: ~$0.075 per 1M input tokens, ~$0.30 per 1M output tokens
|
|
func EstimateCost() float64 {
|
|
// Rough estimate: mostly input tokens (images)
|
|
return float64(OCRTotalTokens) * 0.0000001 // ~$0.10 per 1M tokens average
|
|
}
|