283 lines
8.4 KiB
Go
283 lines
8.4 KiB
Go
package main
|
|
|
|
import (
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"inou/lib"
|
|
"log"
|
|
"os"
|
|
"os/exec"
|
|
"path/filepath"
|
|
"sort"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
)
|
|
|
|
// extractedEntry is one fact parsed from the extraction model's JSON
// response. Field names mirror the schema the extractionPreamble instructs
// the model to emit.
type extractedEntry struct {
	Type string `json:"type"`
	Value string `json:"value"`
	// Summary stays in the document's original language (preamble rule:
	// "Do NOT translate").
	Summary string `json:"summary"`
	SummaryTranslated string `json:"summary_translated,omitempty"`
	// SearchKey is a normalized English lowercase dedup key, e.g.
	// "surgery:vp-shunt:2020-07" — identical facts across documents must
	// produce the same key.
	SearchKey string `json:"search_key,omitempty"`
	// Timestamp is a date string; main parses it against several layouts
	// ("2006-01-02", "02.01.2006", "01/02/2006").
	Timestamp string `json:"timestamp,omitempty"`
	Data map[string]interface{} `json:"data"`
	// SourceSpans locate the verbatim source passages this entry came from;
	// used downstream to highlight the source markdown.
	SourceSpans []sourceSpan `json:"source_spans,omitempty"`
}
|
|
|
|
// sourceSpan pins a passage in the source markdown by its verbatim first
// (Start) and last (End) few words, per the extraction preamble's
// "source_spans" rule.
type sourceSpan struct {
	Start string `json:"start"`
	End string `json:"end"`
}
|
|
|
|
// extractionPreamble is prepended to every per-category extraction prompt.
// It pins the two invariants the rest of the pipeline relies on: no
// translation of source text, and the source_spans / search_key fields used
// for highlighting and cross-document deduplication.
var extractionPreamble = `IMPORTANT RULES (apply to all entries you return):
- Do NOT translate. Keep ALL text values (summary, value, data fields) in the ORIGINAL language of the document.
- For each entry, include "source_spans": an array of {"start": "...", "end": "..."} where start/end are the VERBATIM first and last 5-8 words of the relevant passage(s) in the source markdown. This is used to highlight the source text. Multiple spans are allowed.
- For each entry, include "search_key": a short normalized deduplication key in English lowercase. Format: "thing:qualifier:YYYY-MM" or "thing:qualifier" for undated facts. Examples: "surgery:vp-shunt:2020-07", "device:ommaya-reservoir:2020-04", "diagnosis:hydrocephalus", "provider:peraud:ulm". Same real-world fact across different documents MUST produce the same key.
`
|
|
|
|
// loadExtractionPrompts discovers all extract_*.md files and returns {categoryID: prompt content}.
|
|
func loadExtractionPrompts() map[int]string {
|
|
pattern := filepath.Join(lib.TrackerPromptsDir(), "extract_*.md")
|
|
files, _ := filepath.Glob(pattern)
|
|
prompts := make(map[int]string)
|
|
for _, f := range files {
|
|
base := filepath.Base(f)
|
|
name := strings.TrimPrefix(base, "extract_")
|
|
name = strings.TrimSuffix(name, ".md")
|
|
catID, ok := lib.CategoryFromString[name]
|
|
if !ok {
|
|
fmt.Printf("Unknown category in prompt file: %s\n", base)
|
|
continue
|
|
}
|
|
data, err := os.ReadFile(f)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
prompts[catID] = string(data)
|
|
}
|
|
return prompts
|
|
}
|
|
|
|
// Fireworks model IDs. The vision (OCR) and text (extraction) stages
// currently share the same multimodal model; they are kept as separate
// constants so the two stages can diverge independently.
const (
	visionModel = "accounts/fireworks/models/qwen3-vl-30b-a3b-instruct"
	textModel = "accounts/fireworks/models/qwen3-vl-30b-a3b-instruct"
)
|
|
|
|
// ocrPrompt is the instruction sent alongside the page images in the OCR
// call. It demands a complete, untranslated, structure-preserving markdown
// transcription of every page, in page order.
var ocrPrompt = `You are a medical document OCR system. Produce a faithful markdown transcription of this document.

The images are sequential pages of the same document. Process them in order: page 1 first, then page 2, etc.

Rules:
- Read each page top-to-bottom, left-to-right
- Preserve ALL text, dates, values, names, addresses, and structure
- Translate nothing — keep the original language
- Use markdown headers, lists, and formatting to reflect the document structure
- For tables, use markdown tables. Preserve numeric values exactly.
- Be complete — do not skip or summarize anything
- Do not describe visual elements (logos, signatures) — only transcribe text
- For handwritten text, transcribe as accurately as possible. Mark uncertain readings with [?]`
|
|
|
|
func main() {
|
|
if len(os.Args) < 3 {
|
|
fmt.Fprintf(os.Stderr, "Usage: test-doc-import <dossierID> <pdf-path>\n")
|
|
os.Exit(1)
|
|
}
|
|
dossierID := os.Args[1]
|
|
pdfPath := os.Args[2]
|
|
fileName := filepath.Base(pdfPath)
|
|
|
|
if err := lib.Init(); err != nil {
|
|
log.Fatalf("lib.Init: %v", err)
|
|
}
|
|
lib.ConfigInit()
|
|
lib.InitPrompts("tracker_prompts")
|
|
|
|
fmt.Printf("Prompts dir: %s\n", lib.TrackerPromptsDir())
|
|
|
|
// 1. Convert PDF to PNG pages
|
|
tempDir, _ := os.MkdirTemp("", "doc-import-*")
|
|
defer os.RemoveAll(tempDir)
|
|
|
|
prefix := filepath.Join(tempDir, "page")
|
|
cmd := exec.Command("pdftoppm", "-png", "-r", "200", pdfPath, prefix)
|
|
if out, err := cmd.CombinedOutput(); err != nil {
|
|
log.Fatalf("pdftoppm: %v: %s", err, out)
|
|
}
|
|
|
|
pageFiles, _ := filepath.Glob(prefix + "*.png")
|
|
sort.Strings(pageFiles)
|
|
fmt.Printf("%d pages converted\n", len(pageFiles))
|
|
|
|
// 2. OCR
|
|
content := []interface{}{
|
|
map[string]string{"type": "text", "text": ocrPrompt},
|
|
}
|
|
for _, pf := range pageFiles {
|
|
imgBytes, _ := os.ReadFile(pf)
|
|
b64 := base64.StdEncoding.EncodeToString(imgBytes)
|
|
content = append(content, map[string]interface{}{
|
|
"type": "image_url",
|
|
"image_url": map[string]string{
|
|
"url": "data:image/png;base64," + b64,
|
|
},
|
|
})
|
|
}
|
|
|
|
fmt.Printf("Calling OCR...\n")
|
|
start := time.Now()
|
|
markdown, err := lib.CallFireworks(visionModel, []map[string]interface{}{
|
|
{"role": "user", "content": content},
|
|
}, 16384)
|
|
if err != nil {
|
|
log.Fatalf("OCR: %v", err)
|
|
}
|
|
fmt.Printf("OCR done: %d chars in %.1fs\n", len(markdown), time.Since(start).Seconds())
|
|
|
|
// 3. Create document entry
|
|
now := time.Now().Unix()
|
|
docData := map[string]interface{}{
|
|
"markdown": markdown,
|
|
"pages": len(pageFiles),
|
|
}
|
|
docDataJSON, _ := json.Marshal(docData)
|
|
docEntry := &lib.Entry{
|
|
DossierID: dossierID,
|
|
Category: lib.CategoryDocument,
|
|
Type: "pdf",
|
|
Value: fileName,
|
|
Timestamp: now,
|
|
Data: string(docDataJSON),
|
|
}
|
|
lib.EntryWrite("", docEntry)
|
|
docID := docEntry.EntryID
|
|
fmt.Printf("Document entry: %s\n", docID)
|
|
|
|
// 4. Fan out extraction
|
|
type catResult struct {
|
|
Category int
|
|
Entries []extractedEntry
|
|
}
|
|
var mu sync.Mutex
|
|
var results []catResult
|
|
var wg sync.WaitGroup
|
|
|
|
prompts := loadExtractionPrompts()
|
|
fmt.Printf("Starting %d extraction calls...\n", len(prompts))
|
|
extractStart := time.Now()
|
|
|
|
for catID, promptTmpl := range prompts {
|
|
wg.Add(1)
|
|
go func(catID int, promptTmpl string) {
|
|
defer wg.Done()
|
|
catName := lib.CategoryName(catID)
|
|
|
|
prompt := extractionPreamble + "\n" + strings.ReplaceAll(promptTmpl, "{{MARKDOWN}}", markdown)
|
|
|
|
msgs := []map[string]interface{}{
|
|
{"role": "user", "content": prompt},
|
|
}
|
|
resp, err := lib.CallFireworks(textModel, msgs, 4096)
|
|
if err != nil {
|
|
fmt.Printf(" [%s] API error: %v\n", catName, err)
|
|
return
|
|
}
|
|
resp = strings.TrimSpace(resp)
|
|
if resp == "null" || resp == "" {
|
|
fmt.Printf(" [%s] → null\n", catName)
|
|
return
|
|
}
|
|
|
|
var entries []extractedEntry
|
|
if err := json.Unmarshal([]byte(resp), &entries); err != nil {
|
|
var single extractedEntry
|
|
if err2 := json.Unmarshal([]byte(resp), &single); err2 == nil && single.Summary != "" {
|
|
entries = []extractedEntry{single}
|
|
} else {
|
|
fmt.Printf(" [%s] → parse error: %v\n Response: %s\n", catName, err, resp[:min(200, len(resp))])
|
|
return
|
|
}
|
|
}
|
|
if len(entries) == 0 {
|
|
fmt.Printf(" [%s] → empty array\n", catName)
|
|
return
|
|
}
|
|
fmt.Printf(" [%s] → %d entries\n", catName, len(entries))
|
|
|
|
mu.Lock()
|
|
results = append(results, catResult{Category: catID, Entries: entries})
|
|
mu.Unlock()
|
|
}(catID, promptTmpl)
|
|
}
|
|
wg.Wait()
|
|
fmt.Printf("Extraction done in %.1fs: %d categories\n", time.Since(extractStart).Seconds(), len(results))
|
|
|
|
// 5. Create entries
|
|
var totalEntries int
|
|
for _, r := range results {
|
|
for _, e := range r.Entries {
|
|
dataMap := map[string]interface{}{"source_doc_id": docID}
|
|
for k, v := range e.Data {
|
|
dataMap[k] = v
|
|
}
|
|
if len(e.SourceSpans) > 0 {
|
|
dataMap["source_spans"] = e.SourceSpans
|
|
}
|
|
if e.SummaryTranslated != "" {
|
|
dataMap["summary_translated"] = e.SummaryTranslated
|
|
}
|
|
dataJSON, _ := json.Marshal(dataMap)
|
|
|
|
ts := now
|
|
if e.Timestamp != "" {
|
|
for _, layout := range []string{"2006-01-02", "02.01.2006", "01/02/2006"} {
|
|
if t, err := time.Parse(layout, e.Timestamp); err == nil {
|
|
ts = t.Unix()
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
entry := &lib.Entry{
|
|
DossierID: dossierID,
|
|
ParentID: docID,
|
|
Category: r.Category,
|
|
Type: e.Type,
|
|
Value: e.Value,
|
|
Summary: e.Summary,
|
|
SearchKey: e.SearchKey,
|
|
Timestamp: ts,
|
|
Data: string(dataJSON),
|
|
}
|
|
lib.EntryWrite("", entry)
|
|
totalEntries++
|
|
}
|
|
}
|
|
fmt.Printf("Created %d entries under doc %s\n", totalEntries, docID)
|
|
|
|
// 6. Show results
|
|
fmt.Println("\n=== Results ===")
|
|
for _, r := range results {
|
|
catName := lib.CategoryName(r.Category)
|
|
for _, e := range r.Entries {
|
|
spans := ""
|
|
if len(e.SourceSpans) > 0 {
|
|
spans = fmt.Sprintf(" spans=%d", len(e.SourceSpans))
|
|
}
|
|
trans := ""
|
|
if e.SummaryTranslated != "" {
|
|
trans = fmt.Sprintf(" [%s]", e.SummaryTranslated)
|
|
}
|
|
fmt.Printf(" [%s] Type=%s Summary=%s%s%s\n", catName, e.Type, e.Summary, trans, spans)
|
|
}
|
|
}
|
|
}
|
|
|
|
// min returns the smaller of its two integer arguments.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
|