472 lines
13 KiB
Go
472 lines
13 KiB
Go
package main
|
|
|
|
import (
|
|
"bufio"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
)
|
|
|
|
// PatternStore is the in-memory set of learned patterns, keyed by
// fingerprint, together with the path of the CSV file it is persisted to.
//
// mu guards Patterns. Patterns is exported, so any code touching it directly
// (rather than through the methods) must take mu itself.
type PatternStore struct {
	mu       sync.RWMutex
	Patterns map[string]int // fingerprint -> labeled value
	file     string         // path of the backing "fingerprint,value" CSV file
}
|
|
|
|
// Unlabeled is one queued pattern that still needs a human-assigned value.
// It is also the JSON shape used for queue entries.
type Unlabeled struct {
	ID          int    `json:"id"`          // sequential queue ID
	Fingerprint string `json:"fingerprint"` // pattern key, as used by PatternStore
	ImagePath   string `json:"imagePath"`   // image file shown to the labeler
	Side        string `json:"side"`        // "left" or "right"
}
|
|
|
|
var unlabeled []Unlabeled
|
|
var unlabeledMu sync.Mutex
|
|
var unlabeledID int
|
|
var patternStore *PatternStore
|
|
|
|
// NewPatternStore loads patterns from CSV file (fingerprint,value)
|
|
func NewPatternStore(file string) (*PatternStore, error) {
|
|
ps := &PatternStore{
|
|
Patterns: make(map[string]int),
|
|
file: file,
|
|
}
|
|
|
|
f, err := os.Open(file)
|
|
if err != nil {
|
|
if os.IsNotExist(err) {
|
|
patternStore = ps
|
|
return ps, nil // Empty store
|
|
}
|
|
return nil, err
|
|
}
|
|
defer f.Close()
|
|
|
|
scanner := bufio.NewScanner(f)
|
|
for scanner.Scan() {
|
|
line := strings.TrimSpace(scanner.Text())
|
|
if line == "" || strings.HasPrefix(line, "#") {
|
|
continue
|
|
}
|
|
idx := strings.LastIndex(line, ",")
|
|
if idx < 0 {
|
|
continue
|
|
}
|
|
fingerprint := line[:idx]
|
|
value, err := strconv.Atoi(line[idx+1:])
|
|
if err != nil {
|
|
continue
|
|
}
|
|
ps.Patterns[fingerprint] = value
|
|
}
|
|
|
|
patternStore = ps
|
|
return ps, nil
|
|
}
|
|
|
|
// Lookup returns the value for a fingerprint
|
|
func (ps *PatternStore) Lookup(fingerprint string) (int, bool) {
|
|
ps.mu.RLock()
|
|
defer ps.mu.RUnlock()
|
|
val, ok := ps.Patterns[fingerprint]
|
|
return val, ok
|
|
}
|
|
|
|
// min returns the smaller of a and b.
//
// NOTE(review): Go 1.21+ provides a built-in min; this package-level
// definition shadows it throughout package main.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
|
|
|
|
// Store saves a pattern -> value mapping
|
|
func (ps *PatternStore) Store(fingerprint string, value int) error {
|
|
ps.mu.Lock()
|
|
ps.Patterns[fingerprint] = value
|
|
count := len(ps.Patterns)
|
|
ps.mu.Unlock()
|
|
fmt.Printf("Stored pattern (total: %d)\n", count)
|
|
return ps.Save()
|
|
}
|
|
|
|
// Save writes patterns to CSV file (fingerprint,value)
|
|
func (ps *PatternStore) Save() error {
|
|
ps.mu.RLock()
|
|
defer ps.mu.RUnlock()
|
|
|
|
var lines []string
|
|
for fp, val := range ps.Patterns {
|
|
lines = append(lines, fmt.Sprintf("%s,%d", fp, val))
|
|
}
|
|
return os.WriteFile(ps.file, []byte(strings.Join(lines, "\n")+"\n"), 0644)
|
|
}
|
|
|
|
// Count returns number of patterns
|
|
func (ps *PatternStore) Count() int {
|
|
ps.mu.RLock()
|
|
defer ps.mu.RUnlock()
|
|
return len(ps.Patterns)
|
|
}
|
|
|
|
// AddUnlabeled adds a pattern for manual labeling, returns true if added (not duplicate)
|
|
func AddUnlabeled(fingerprint, imagePath, side string) bool {
|
|
unlabeledMu.Lock()
|
|
defer unlabeledMu.Unlock()
|
|
|
|
// Check if already in list
|
|
for _, u := range unlabeled {
|
|
if u.Fingerprint == fingerprint {
|
|
return false
|
|
}
|
|
}
|
|
|
|
unlabeledID++
|
|
unlabeled = append(unlabeled, Unlabeled{
|
|
ID: unlabeledID,
|
|
Fingerprint: fingerprint,
|
|
ImagePath: imagePath,
|
|
Side: side,
|
|
})
|
|
fmt.Printf("New unlabeled %s: %s (queue: %d)\n", side, fingerprint, len(unlabeled))
|
|
return true
|
|
}
|
|
|
|
// RemoveUnlabeled removes a pattern from the unlabeled list
|
|
func RemoveUnlabeled(fingerprint string) {
|
|
unlabeledMu.Lock()
|
|
defer unlabeledMu.Unlock()
|
|
|
|
for i, u := range unlabeled {
|
|
if u.Fingerprint == fingerprint {
|
|
unlabeled = append(unlabeled[:i], unlabeled[i+1:]...)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// StartWebServer starts the training web interface
|
|
func StartWebServer(port int) {
|
|
http.HandleFunc("/", handleIndex)
|
|
http.HandleFunc("/api/unlabeled", handleUnlabeled)
|
|
http.HandleFunc("/api/label", handleLabel)
|
|
http.HandleFunc("/api/stats", handleStats)
|
|
|
|
addr := fmt.Sprintf(":%d", port)
|
|
fmt.Printf("Training interface at http://localhost%s\n", addr)
|
|
go http.ListenAndServe(addr, nil)
|
|
}
|
|
|
|
func handleIndex(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "text/html")
|
|
w.Write([]byte(indexHTML))
|
|
}
|
|
|
|
func handleUnlabeled(w http.ResponseWriter, r *http.Request) {
|
|
unlabeledMu.Lock()
|
|
defer unlabeledMu.Unlock()
|
|
|
|
// Build response with embedded images
|
|
type UnlabeledResponse struct {
|
|
ID int `json:"id"`
|
|
Fingerprint string `json:"fingerprint"`
|
|
ImageData string `json:"imageData"`
|
|
Side string `json:"side"`
|
|
}
|
|
|
|
var resp []UnlabeledResponse
|
|
for _, u := range unlabeled {
|
|
imgData, err := os.ReadFile(u.ImagePath)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
resp = append(resp, UnlabeledResponse{
|
|
ID: u.ID,
|
|
Fingerprint: u.Fingerprint,
|
|
ImageData: base64.StdEncoding.EncodeToString(imgData),
|
|
Side: u.Side,
|
|
})
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(resp)
|
|
}
|
|
|
|
func handleLabel(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != "POST" {
|
|
http.Error(w, "POST only", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
fingerprint := r.FormValue("fingerprint")
|
|
valueStr := r.FormValue("value")
|
|
|
|
if fingerprint == "" || valueStr == "" {
|
|
http.Error(w, "Missing fingerprint or value", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// "x" or "-1" means ignore this pattern
|
|
var value int
|
|
if valueStr == "x" || valueStr == "X" || valueStr == "-1" {
|
|
value = -1
|
|
} else {
|
|
var err error
|
|
value, err = strconv.Atoi(valueStr)
|
|
if err != nil {
|
|
http.Error(w, "Invalid value", http.StatusBadRequest)
|
|
return
|
|
}
|
|
}
|
|
|
|
// Store pattern
|
|
if patternStore != nil {
|
|
if err := patternStore.Store(fingerprint, value); err != nil {
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
fmt.Printf("Learned: %s = %d\n", fingerprint, value)
|
|
}
|
|
|
|
// Remove from unlabeled
|
|
RemoveUnlabeled(fingerprint)
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
|
|
}
|
|
|
|
func handleStats(w http.ResponseWriter, r *http.Request) {
|
|
unlabeledMu.Lock()
|
|
unlabeledCount := len(unlabeled)
|
|
unlabeledMu.Unlock()
|
|
|
|
patternCount := 0
|
|
if patternStore != nil {
|
|
patternCount = patternStore.Count()
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(map[string]int{
|
|
"patterns": patternCount,
|
|
"unlabeled": unlabeledCount,
|
|
})
|
|
}
|
|
|
|
// indexHTML is the single-page training UI served by handleIndex. It polls
// /api/stats and /api/unlabeled once per second and POSTs each entered label
// to /api/label.
//
// Cards are built with DOM APIs (textContent/dataset) rather than innerHTML
// interpolation, so fingerprints are always treated as text, never as markup.
// A failed save leaves the card editable (marked with the "error" class)
// instead of permanently disabled.
const indexHTML = `<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>Pulse Monitor Training</title>
<style>
  * { box-sizing: border-box; }
  body {
    font-family: system-ui, -apple-system, sans-serif;
    background: #1a1a2e;
    color: #eee;
    padding: 20px;
    margin: 0;
  }
  h1 { color: #0f0; margin: 0 0 10px 0; }
  .stats {
    background: #16213e;
    padding: 15px;
    border-radius: 8px;
    margin-bottom: 20px;
    display: flex;
    gap: 30px;
  }
  .stat { text-align: center; }
  .stat-value { font-size: 32px; font-weight: bold; color: #0f0; }
  .stat-label { font-size: 12px; color: #888; }
  .patterns {
    display: flex;
    flex-wrap: wrap;
    gap: 15px;
  }
  .pattern {
    background: #16213e;
    padding: 15px;
    border-radius: 8px;
    text-align: center;
    border: 2px solid transparent;
  }
  .pattern.left { border-color: #4a9; }
  .pattern.right { border-color: #94a; }
  .pattern.error { border-color: #f44; }
  .pattern img {
    display: block;
    margin: 0 auto 10px;
    image-rendering: pixelated;
    max-width: 300px;
  }
  .pattern input {
    width: 80px;
    font-size: 20px;
    padding: 8px;
    text-align: center;
    border: none;
    border-radius: 4px;
    background: #0f3460;
    color: #fff;
  }
  .pattern input:focus {
    outline: 2px solid #0f0;
  }
  .fingerprint {
    font-size: 11px;
    color: #666;
    word-break: break-all;
    max-width: 250px;
    margin-top: 10px;
  }
  .side-label {
    font-size: 12px;
    font-weight: bold;
    margin-bottom: 8px;
  }
  .side-label.left { color: #4a9; }
  .side-label.right { color: #94a; }
  .empty {
    color: #666;
    padding: 40px;
    text-align: center;
  }
  .saved {
    animation: flash 0.5s;
  }
  @keyframes flash {
    0% { background: #0f0; }
    100% { background: #16213e; }
  }
</style>
</head>
<body>
<h1>Pulse Monitor Training</h1>
<div class="stats">
  <div class="stat">
    <div class="stat-value" id="patternCount">-</div>
    <div class="stat-label">Learned Patterns</div>
  </div>
  <div class="stat">
    <div class="stat-value" id="unlabeledCount">-</div>
    <div class="stat-label">Waiting for Label</div>
  </div>
</div>
<div class="patterns" id="patterns">
  <div class="empty">Waiting for patterns...</div>
</div>

<script>
// Fingerprints that currently have a card on the page.
const knownPatterns = new Set();

// Replace the container contents with a single placeholder message.
function showEmpty(container, text) {
  container.textContent = '';
  const msg = document.createElement('div');
  msg.className = 'empty';
  msg.textContent = text;
  container.appendChild(msg);
}

// Build one card using DOM APIs so fingerprint/side are always plain text.
function createCard(p) {
  const card = document.createElement('div');
  card.className = 'pattern ' + p.side;
  card.id = 'p-' + p.id;
  card.dataset.fp = p.fingerprint;

  const label = document.createElement('div');
  label.className = 'side-label ' + p.side;
  label.textContent = p.side.toUpperCase();

  const img = document.createElement('img');
  img.src = 'data:image/png;base64,' + p.imageData;

  const input = document.createElement('input');
  input.type = 'text';
  input.placeholder = 'value';
  input.dataset.fp = p.fingerprint;
  input.addEventListener('keydown', handleKey);

  const fp = document.createElement('div');
  fp.className = 'fingerprint';
  fp.textContent = p.fingerprint;

  card.append(label, img, input, fp);
  return card;
}

async function refresh() {
  // Update stats
  const stats = await fetch('/api/stats').then(r => r.json());
  document.getElementById('patternCount').textContent = stats.patterns;
  document.getElementById('unlabeledCount').textContent = stats.unlabeled;

  // Get unlabeled patterns
  const unlabeled = await fetch('/api/unlabeled').then(r => r.json());
  const container = document.getElementById('patterns');

  if (!unlabeled || unlabeled.length === 0) {
    if (container.querySelector('.pattern')) {
      showEmpty(container, 'All patterns labeled! Waiting for new ones...');
    }
    // No cards remain, so forget them; otherwise a fingerprint that is
    // queued again later would never be shown.
    knownPatterns.clear();
    return;
  }

  // Add new patterns
  unlabeled.forEach(p => {
    if (knownPatterns.has(p.fingerprint)) return;
    knownPatterns.add(p.fingerprint);

    // Remove empty message if present
    const empty = container.querySelector('.empty');
    if (empty) empty.remove();

    container.appendChild(createCard(p));
  });

  // Remove patterns that are no longer unlabeled
  const currentFingerprints = new Set(unlabeled.map(p => p.fingerprint));
  container.querySelectorAll('.pattern').forEach(card => {
    const fp = card.dataset.fp;
    if (!currentFingerprints.has(fp)) {
      card.classList.add('saved');
      setTimeout(() => card.remove(), 500);
      knownPatterns.delete(fp);
    }
  });

  // Focus first input only if no input currently has focus
  const activeEl = document.activeElement;
  const hasInputFocus = activeEl && activeEl.tagName === 'INPUT' && activeEl.closest('.pattern');
  if (!hasInputFocus) {
    const firstInput = container.querySelector('.pattern input:not(:disabled)');
    if (firstInput) firstInput.focus();
  }
}

async function handleKey(e) {
  if (e.key !== 'Enter') return;
  e.preventDefault();

  const input = e.target;
  const value = input.value.trim();
  if (!value) return;

  const fingerprint = input.dataset.fp;
  const card = input.closest('.pattern');

  // Immediately find and focus next input BEFORE any async work
  const allInputs = Array.from(document.querySelectorAll('.pattern input'));
  const idx = allInputs.indexOf(input);
  const nextInput = allInputs[idx + 1] || allInputs[0];
  if (nextInput && nextInput !== input) {
    nextInput.focus();
  }

  // Disable this input and mark card as saving
  input.disabled = true;
  card.style.opacity = '0.5';
  card.classList.remove('error');

  // Save to server
  const formData = new FormData();
  formData.append('fingerprint', fingerprint);
  formData.append('value', value);

  let saved = false;
  try {
    const res = await fetch('/api/label', { method: 'POST', body: formData });
    saved = res.ok;
  } catch (err) {
    console.error('saving label failed:', err);
  }

  if (!saved) {
    // Keep the card so the label can be corrected and retried.
    input.disabled = false;
    card.style.opacity = '';
    card.classList.add('error');
    return;
  }

  // Remove card; forget it only once it is gone so a stale poll that still
  // lists it cannot add a duplicate in the meantime.
  card.classList.add('saved');
  setTimeout(() => {
    card.remove();
    knownPatterns.delete(fingerprint);
  }, 300);
}

// Poll for updates; a failed poll is logged and retried on the next tick.
function poll() {
  refresh().catch(err => console.error('refresh failed:', err));
}
poll();
setInterval(poll, 1000);
</script>
</body>
</html>`
|
|
|
|
// GenerateHTML is kept only for API compatibility. It once wrote a
// training.html file into dir, but the page is now served by StartWebServer.
// It does nothing and always returns nil.
//
// Deprecated: use StartWebServer; calls to GenerateHTML can be removed.
func GenerateHTML(dir string) error {
	return nil
}
|