// inou/tools/test-doc-import/main.go
//
// 283 lines
// 8.4 KiB
// Go

package main
import (
"encoding/base64"
"encoding/json"
"fmt"
"inou/lib"
"log"
"os"
"os/exec"
"path/filepath"
"sort"
"strings"
"sync"
"time"
)
// extractedEntry is one fact returned by a category extraction call, decoded
// from the model's JSON reply. The JSON names are presumably the keys the
// extract_*.md prompts ask for (the preamble asks for source_spans and
// search_key explicitly).
type extractedEntry struct {
	Type    string `json:"type"`
	Value   string `json:"value"`
	Summary string `json:"summary"`
	// SummaryTranslated is optional; when non-empty it is copied into the
	// stored entry's data under "summary_translated".
	SummaryTranslated string `json:"summary_translated,omitempty"`
	// SearchKey is the normalized deduplication key, e.g. "diagnosis:hydrocephalus".
	SearchKey string `json:"search_key,omitempty"`
	// Timestamp is a free-form date string; main tries a few date layouts
	// and falls back to the import time when none match.
	Timestamp string `json:"timestamp,omitempty"`
	// Data holds category-specific fields that are merged into the stored
	// entry's data JSON.
	Data        map[string]interface{} `json:"data"`
	SourceSpans []sourceSpan           `json:"source_spans,omitempty"`
}

// sourceSpan locates a passage in the OCR markdown by its verbatim first
// and last words, so the UI can highlight where an entry came from.
type sourceSpan struct {
	Start string `json:"start"`
	End   string `json:"end"`
}
// extractionPreamble is prepended, followed by a newline, to every
// category-specific extraction prompt before it is sent to the text model.
// It tells the model to keep the document's original language, to return
// verbatim "source_spans" for highlighting, and to emit an English
// "search_key" used as a cross-document deduplication key.
var extractionPreamble = `IMPORTANT RULES (apply to all entries you return):
- Do NOT translate. Keep ALL text values (summary, value, data fields) in the ORIGINAL language of the document.
- For each entry, include "source_spans": an array of {"start": "...", "end": "..."} where start/end are the VERBATIM first and last 5-8 words of the relevant passage(s) in the source markdown. This is used to highlight the source text. Multiple spans are allowed.
- For each entry, include "search_key": a short normalized deduplication key in English lowercase. Format: "thing:qualifier:YYYY-MM" or "thing:qualifier" for undated facts. Examples: "surgery:vp-shunt:2020-07", "device:ommaya-reservoir:2020-04", "diagnosis:hydrocephalus", "provider:peraud:ulm". Same real-world fact across different documents MUST produce the same key.
`
// loadExtractionPrompts discovers all extract_*.md files in the tracker
// prompts directory and returns {categoryID: prompt content}.
//
// The category comes from the file name (extract_<name>.md) and is looked
// up in lib.CategoryFromString. Files with an unknown category name, or that
// cannot be read, are reported on stdout and skipped, so one bad prompt file
// does not abort the run. If no files match, the result is an empty map.
func loadExtractionPrompts() map[int]string {
	pattern := filepath.Join(lib.TrackerPromptsDir(), "extract_*.md")
	files, err := filepath.Glob(pattern)
	if err != nil {
		// Glob's only error is ErrBadPattern, which can happen when the
		// prompts directory path itself contains glob metacharacters.
		fmt.Printf("Bad prompt file pattern %s: %v\n", pattern, err)
		return map[int]string{}
	}
	prompts := make(map[int]string, len(files))
	for _, f := range files {
		base := filepath.Base(f)
		name := strings.TrimSuffix(strings.TrimPrefix(base, "extract_"), ".md")
		catID, ok := lib.CategoryFromString[name]
		if !ok {
			fmt.Printf("Unknown category in prompt file: %s\n", base)
			continue
		}
		data, err := os.ReadFile(f)
		if err != nil {
			// Report the failure instead of silently dropping the category.
			fmt.Printf("Cannot read prompt file %s: %v\n", base, err)
			continue
		}
		prompts[catID] = string(data)
	}
	return prompts
}
// visionModel is the Fireworks model used for the OCR call over the page images.
const visionModel = "accounts/fireworks/models/qwen3-vl-30b-a3b-instruct"

// textModel is the Fireworks model used for the per-category extraction
// calls. It is currently the same model ID as visionModel.
const textModel = "accounts/fireworks/models/qwen3-vl-30b-a3b-instruct"
// ocrPrompt is the text part of the single vision-model request; the page
// images are appended after it as image_url parts, in sorted page-file order.
// It asks for a faithful, untranslated markdown transcription.
var ocrPrompt = `You are a medical document OCR system. Produce a faithful markdown transcription of this document.
The images are sequential pages of the same document. Process them in order: page 1 first, then page 2, etc.
Rules:
- Read each page top-to-bottom, left-to-right
- Preserve ALL text, dates, values, names, addresses, and structure
- Translate nothing — keep the original language
- Use markdown headers, lists, and formatting to reflect the document structure
- For tables, use markdown tables. Preserve numeric values exactly.
- Be complete — do not skip or summarize anything
- Do not describe visual elements (logos, signatures) — only transcribe text
- For handwritten text, transcribe as accurately as possible. Mark uncertain readings with [?]`
func main() {
if len(os.Args) < 3 {
fmt.Fprintf(os.Stderr, "Usage: test-doc-import <dossierID> <pdf-path>\n")
os.Exit(1)
}
dossierID := os.Args[1]
pdfPath := os.Args[2]
fileName := filepath.Base(pdfPath)
if err := lib.Init(); err != nil {
log.Fatalf("lib.Init: %v", err)
}
lib.ConfigInit()
lib.InitPrompts("tracker_prompts")
fmt.Printf("Prompts dir: %s\n", lib.TrackerPromptsDir())
// 1. Convert PDF to PNG pages
tempDir, _ := os.MkdirTemp("", "doc-import-*")
defer os.RemoveAll(tempDir)
prefix := filepath.Join(tempDir, "page")
cmd := exec.Command("pdftoppm", "-png", "-r", "200", pdfPath, prefix)
if out, err := cmd.CombinedOutput(); err != nil {
log.Fatalf("pdftoppm: %v: %s", err, out)
}
pageFiles, _ := filepath.Glob(prefix + "*.png")
sort.Strings(pageFiles)
fmt.Printf("%d pages converted\n", len(pageFiles))
// 2. OCR
content := []interface{}{
map[string]string{"type": "text", "text": ocrPrompt},
}
for _, pf := range pageFiles {
imgBytes, _ := os.ReadFile(pf)
b64 := base64.StdEncoding.EncodeToString(imgBytes)
content = append(content, map[string]interface{}{
"type": "image_url",
"image_url": map[string]string{
"url": "data:image/png;base64," + b64,
},
})
}
fmt.Printf("Calling OCR...\n")
start := time.Now()
markdown, err := lib.CallFireworks(visionModel, []map[string]interface{}{
{"role": "user", "content": content},
}, 16384)
if err != nil {
log.Fatalf("OCR: %v", err)
}
fmt.Printf("OCR done: %d chars in %.1fs\n", len(markdown), time.Since(start).Seconds())
// 3. Create document entry
now := time.Now().Unix()
docData := map[string]interface{}{
"markdown": markdown,
"pages": len(pageFiles),
}
docDataJSON, _ := json.Marshal(docData)
docEntry := &lib.Entry{
DossierID: dossierID,
Category: lib.CategoryDocument,
Type: "pdf",
Value: fileName,
Timestamp: now,
Data: string(docDataJSON),
}
lib.EntryWrite("", docEntry)
docID := docEntry.EntryID
fmt.Printf("Document entry: %s\n", docID)
// 4. Fan out extraction
type catResult struct {
Category int
Entries []extractedEntry
}
var mu sync.Mutex
var results []catResult
var wg sync.WaitGroup
prompts := loadExtractionPrompts()
fmt.Printf("Starting %d extraction calls...\n", len(prompts))
extractStart := time.Now()
for catID, promptTmpl := range prompts {
wg.Add(1)
go func(catID int, promptTmpl string) {
defer wg.Done()
catName := lib.CategoryName(catID)
prompt := extractionPreamble + "\n" + strings.ReplaceAll(promptTmpl, "{{MARKDOWN}}", markdown)
msgs := []map[string]interface{}{
{"role": "user", "content": prompt},
}
resp, err := lib.CallFireworks(textModel, msgs, 4096)
if err != nil {
fmt.Printf(" [%s] API error: %v\n", catName, err)
return
}
resp = strings.TrimSpace(resp)
if resp == "null" || resp == "" {
fmt.Printf(" [%s] → null\n", catName)
return
}
var entries []extractedEntry
if err := json.Unmarshal([]byte(resp), &entries); err != nil {
var single extractedEntry
if err2 := json.Unmarshal([]byte(resp), &single); err2 == nil && single.Summary != "" {
entries = []extractedEntry{single}
} else {
fmt.Printf(" [%s] → parse error: %v\n Response: %s\n", catName, err, resp[:min(200, len(resp))])
return
}
}
if len(entries) == 0 {
fmt.Printf(" [%s] → empty array\n", catName)
return
}
fmt.Printf(" [%s] → %d entries\n", catName, len(entries))
mu.Lock()
results = append(results, catResult{Category: catID, Entries: entries})
mu.Unlock()
}(catID, promptTmpl)
}
wg.Wait()
fmt.Printf("Extraction done in %.1fs: %d categories\n", time.Since(extractStart).Seconds(), len(results))
// 5. Create entries
var totalEntries int
for _, r := range results {
for _, e := range r.Entries {
dataMap := map[string]interface{}{"source_doc_id": docID}
for k, v := range e.Data {
dataMap[k] = v
}
if len(e.SourceSpans) > 0 {
dataMap["source_spans"] = e.SourceSpans
}
if e.SummaryTranslated != "" {
dataMap["summary_translated"] = e.SummaryTranslated
}
dataJSON, _ := json.Marshal(dataMap)
ts := now
if e.Timestamp != "" {
for _, layout := range []string{"2006-01-02", "02.01.2006", "01/02/2006"} {
if t, err := time.Parse(layout, e.Timestamp); err == nil {
ts = t.Unix()
break
}
}
}
entry := &lib.Entry{
DossierID: dossierID,
ParentID: docID,
Category: r.Category,
Type: e.Type,
Value: e.Value,
Summary: e.Summary,
SearchKey: e.SearchKey,
Timestamp: ts,
Data: string(dataJSON),
}
lib.EntryWrite("", entry)
totalEntries++
}
}
fmt.Printf("Created %d entries under doc %s\n", totalEntries, docID)
// 6. Show results
fmt.Println("\n=== Results ===")
for _, r := range results {
catName := lib.CategoryName(r.Category)
for _, e := range r.Entries {
spans := ""
if len(e.SourceSpans) > 0 {
spans = fmt.Sprintf(" spans=%d", len(e.SourceSpans))
}
trans := ""
if e.SummaryTranslated != "" {
trans = fmt.Sprintf(" [%s]", e.SummaryTranslated)
}
fmt.Printf(" [%s] Type=%s Summary=%s%s%s\n", catName, e.Type, e.Summary, trans, spans)
}
}
}
// min returns the smaller of two ints. This package-level definition
// shadows Go's built-in min (Go 1.21+) within package main.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}