dealroom/internal/worker/extractor.go

218 lines
5.8 KiB
Go

package worker
import (
"context"
"crypto/rand"
"database/sql"
"encoding/binary"
"encoding/hex"
"fmt"
"log"
"math"
"time"
"dealroom/internal/extract"
"dealroom/internal/fireworks"
)
// matchThreshold is the minimum cosine similarity required before a response
// chunk is auto-linked to a diligence request (see matchRequests).
const matchThreshold = 0.72
// ExtractionJob describes one response whose content should be extracted,
// chunked, embedded, and matched against the deal's open diligence requests.
type ExtractionJob struct {
	ResponseID string // id of the responses row to process
	FilePath   string // absolute path to uploaded file (or "" for statements)
	DealID     string // deal whose diligence_requests are matched against
}
// Extractor runs background workers that turn uploaded responses into
// markdown, chunk and embed the text, and auto-link chunks to requests.
type Extractor struct {
	db   *sql.DB            // application database
	fw   *fireworks.Client  // LLM extraction / embedding client
	jobs chan ExtractionJob // buffered work queue fed by Enqueue
}
// NewExtractor constructs an Extractor backed by db and fw with a buffered
// job queue. Call Start to begin consuming jobs.
func NewExtractor(db *sql.DB, fw *fireworks.Client) *Extractor {
	const queueDepth = 100 // bound on pending jobs before Enqueue blocks
	return &Extractor{
		db:   db,
		fw:   fw,
		jobs: make(chan ExtractionJob, queueDepth),
	}
}
// Start launches the fixed pool of two worker goroutines that consume the
// job queue. It returns immediately; workers run until the queue is closed.
func (e *Extractor) Start() {
	for id := 0; id < 2; id++ {
		go e.worker(id)
	}
	log.Println("Extraction worker started (2 goroutines)")
}
// Enqueue submits a job for background processing. It blocks if the
// buffered queue (capacity 100) is full.
func (e *Extractor) Enqueue(job ExtractionJob) {
	e.jobs <- job
}
// worker drains the jobs channel until it is closed, processing one job at
// a time; id is used only for log attribution.
func (e *Extractor) worker(id int) {
	for job := range e.jobs {
		e.process(id, job)
	}
}
// process runs the full pipeline for one job: mark the response as
// processing, obtain its markdown body (extracting from the attached file
// when there is one), chunk and embed the text, persist the chunks, and
// auto-link them to matching diligence requests in the same deal.
// Extraction/embedding failures flip extraction_status to 'failed'; all
// database errors are logged rather than silently dropped.
func (e *Extractor) process(workerID int, job ExtractionJob) {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
	defer cancel()
	log.Printf("[extractor-%d] Processing response %s (deal=%s, file=%s)", workerID, job.ResponseID, job.DealID, job.FilePath)

	// Set status to processing.
	e.setStatus(job.ResponseID, "processing")

	var body string
	if job.FilePath != "" {
		// Document: extract markdown from the uploaded file.
		md, err := e.extractFile(ctx, job)
		if err != nil {
			log.Printf("[extractor-%d] Extraction failed for %s: %v", workerID, job.ResponseID, err)
			e.setStatus(job.ResponseID, "failed")
			return
		}
		body = md
		// Persist the extracted body and mark extraction complete.
		if _, err := e.db.Exec("UPDATE responses SET body = ?, extraction_status = 'done', updated_at = datetime('now') WHERE id = ?", body, job.ResponseID); err != nil {
			log.Printf("[extractor-%d] Failed to store body for %s: %v", workerID, job.ResponseID, err)
		}
	} else {
		// Statement: body is already set, just mark done and load it.
		e.setStatus(job.ResponseID, "done")
		if err := e.db.QueryRow("SELECT body FROM responses WHERE id = ?", job.ResponseID).Scan(&body); err != nil {
			// Without a body there is nothing to chunk or match.
			log.Printf("[extractor-%d] Failed to load body for %s: %v", workerID, job.ResponseID, err)
			return
		}
	}
	if body == "" {
		log.Printf("[extractor-%d] Empty body for response %s, skipping chunk+match", workerID, job.ResponseID)
		return
	}

	// Chunk the markdown into embeddable pieces.
	chunks := extract.ChunkMarkdown(body)
	if len(chunks) == 0 {
		log.Printf("[extractor-%d] No chunks produced for response %s", workerID, job.ResponseID)
		return
	}

	// Embed all chunks in one batch call.
	chunkVectors, err := e.fw.EmbedText(ctx, chunks)
	if err != nil {
		log.Printf("[extractor-%d] Embedding failed for %s: %v", workerID, job.ResponseID, err)
		e.setStatus(job.ResponseID, "failed")
		return
	}

	// Store chunks with their vectors (packed little-endian float32 blobs).
	chunkIDs := make([]string, len(chunks))
	for i, chunk := range chunks {
		chunkID := generateID("chunk")
		chunkIDs[i] = chunkID
		vecBytes := float32sToBytes(chunkVectors[i])
		if _, err := e.db.Exec("INSERT INTO response_chunks (id, response_id, chunk_index, text, vector) VALUES (?, ?, ?, ?, ?)",
			chunkID, job.ResponseID, i, chunk, vecBytes); err != nil {
			log.Printf("[extractor-%d] Failed to store chunk %d for %s: %v", workerID, i, job.ResponseID, err)
		}
	}

	// Match against requests in this deal and auto-create links.
	linkCount := e.matchRequests(ctx, job.DealID, job.ResponseID, chunkIDs, chunkVectors)
	log.Printf("[extractor-%d] Response %s: %d chunks, %d request links auto-created", workerID, job.ResponseID, len(chunks), linkCount)
}

// setStatus updates a response's extraction_status, logging (but not
// propagating) any database error — status updates are best-effort.
func (e *Extractor) setStatus(responseID, status string) {
	if _, err := e.db.Exec("UPDATE responses SET extraction_status = ?, updated_at = datetime('now') WHERE id = ?", status, responseID); err != nil {
		log.Printf("[extractor] Failed to set status %q for %s: %v", status, responseID, err)
	}
}
// extractFile converts the job's uploaded file into markdown.
// XLSX files take a text-dump + LLM text-to-markdown path; everything else
// (PDF or image) is rasterized and sent through vision extraction.
func (e *Extractor) extractFile(ctx context.Context, job ExtractionJob) (string, error) {
	path := job.FilePath

	if !extract.IsXLSX(path) {
		// PDF or image: rasterize pages, then run vision extraction.
		images, err := extract.FileToImages(path)
		if err != nil {
			return "", fmt.Errorf("file to images: %w", err)
		}
		if len(images) == 0 {
			return "", fmt.Errorf("no images extracted from file")
		}
		md, err := e.fw.ExtractToMarkdown(ctx, images, path)
		if err != nil {
			return "", fmt.Errorf("vision extraction: %w", err)
		}
		return md, nil
	}

	// XLSX: dump cell text, then have the LLM structure it as markdown.
	text, err := extract.XLSXToText(path)
	if err != nil {
		return "", fmt.Errorf("xlsx extract: %w", err)
	}
	md, err := e.fw.ExtractTextToMarkdown(ctx, text, path)
	if err != nil {
		return "", fmt.Errorf("xlsx to markdown: %w", err)
	}
	return md, nil
}
// matchRequests embeds every diligence-request description for the deal and
// auto-links each (chunk, request) pair whose cosine similarity meets
// matchThreshold. It returns the number of links actually inserted.
// Failures are logged and treated as "no matches" so the caller's pipeline
// still completes.
func (e *Extractor) matchRequests(ctx context.Context, dealID, responseID string, chunkIDs []string, chunkVectors [][]float32) int {
	// Load all requests for this deal.
	rows, err := e.db.Query("SELECT id, description FROM diligence_requests WHERE deal_id = ?", dealID)
	if err != nil {
		log.Printf("[extractor] Failed to load requests for deal %s: %v", dealID, err)
		return 0
	}
	defer rows.Close()

	type reqInfo struct {
		id, desc string
	}
	var reqs []reqInfo
	for rows.Next() {
		var r reqInfo
		// A scan failure would otherwise yield an empty id/description that
		// still gets embedded and linked — bail out instead.
		if err := rows.Scan(&r.id, &r.desc); err != nil {
			log.Printf("[extractor] Failed to scan request for deal %s: %v", dealID, err)
			return 0
		}
		reqs = append(reqs, r)
	}
	if err := rows.Err(); err != nil {
		log.Printf("[extractor] Request iteration failed for deal %s: %v", dealID, err)
		return 0
	}
	if len(reqs) == 0 {
		return 0
	}

	// Embed request descriptions in one batch call.
	descs := make([]string, len(reqs))
	for i, r := range reqs {
		descs[i] = r.desc
	}
	reqVectors, err := e.fw.EmbedText(ctx, descs)
	if err != nil {
		log.Printf("[extractor] Failed to embed request descriptions: %v", err)
		return 0
	}

	// Link every (chunk, request) pair at or above the similarity threshold.
	// INSERT OR IGNORE makes re-processing idempotent.
	linkCount := 0
	for ci, chunkVec := range chunkVectors {
		for ri, reqVec := range reqVectors {
			sim := fireworks.CosineSimilarity(chunkVec, reqVec)
			if sim < matchThreshold {
				continue
			}
			_, err := e.db.Exec(
				"INSERT OR IGNORE INTO request_links (request_id, response_id, chunk_id, confidence, auto_linked, confirmed) VALUES (?, ?, ?, ?, 1, 0)",
				reqs[ri].id, responseID, chunkIDs[ci], sim)
			if err != nil {
				log.Printf("[extractor] Failed to insert request link: %v", err)
				continue
			}
			linkCount++
		}
	}
	return linkCount
}
// float32sToBytes serializes fs as a packed little-endian float32 blob
// (4 bytes per element), suitable for storage in a BLOB column.
func float32sToBytes(fs []float32) []byte {
	out := make([]byte, 4*len(fs))
	for idx := range fs {
		word := math.Float32bits(fs[idx])
		binary.LittleEndian.PutUint32(out[4*idx:4*idx+4], word)
	}
	return out
}
// generateID returns a new identifier of the form "<prefix>-<16 hex chars>"
// built from 8 cryptographically random bytes.
func generateID(prefix string) string {
	b := make([]byte, 8)
	// crypto/rand failing means the platform RNG is broken; IDs built from
	// zeroed bytes would collide, so fail loudly rather than continue.
	if _, err := rand.Read(b); err != nil {
		panic(fmt.Sprintf("generateID: reading random bytes: %v", err))
	}
	return prefix + "-" + hex.EncodeToString(b)
}