218 lines
5.8 KiB
Go
218 lines
5.8 KiB
Go
package worker
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"database/sql"
|
|
"encoding/binary"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"log"
|
|
"math"
|
|
"time"
|
|
|
|
"dealroom/internal/extract"
|
|
"dealroom/internal/fireworks"
|
|
)
|
|
|
|
// matchThreshold is the minimum cosine similarity between a chunk vector and a
// request-description vector at which matchRequests records an automatic link.
const matchThreshold = 0.72
|
|
|
|
// ExtractionJob describes one response that has been queued for extraction,
// chunking and request matching.
type ExtractionJob struct {
	// ResponseID is the id of the row in the responses table being processed.
	ResponseID string

	// FilePath is the absolute path to the uploaded file, or "" for statements
	// whose body is already stored.
	FilePath string

	// DealID identifies the deal whose diligence requests are matched against
	// the response's chunks.
	DealID string
}
|
|
|
|
type Extractor struct {
|
|
db *sql.DB
|
|
fw *fireworks.Client
|
|
jobs chan ExtractionJob
|
|
}
|
|
|
|
func NewExtractor(db *sql.DB, fw *fireworks.Client) *Extractor {
|
|
return &Extractor{
|
|
db: db,
|
|
fw: fw,
|
|
jobs: make(chan ExtractionJob, 100),
|
|
}
|
|
}
|
|
|
|
func (e *Extractor) Start() {
|
|
for i := 0; i < 2; i++ {
|
|
go e.worker(i)
|
|
}
|
|
log.Println("Extraction worker started (2 goroutines)")
|
|
}
|
|
|
|
func (e *Extractor) Enqueue(job ExtractionJob) {
|
|
e.jobs <- job
|
|
}
|
|
|
|
func (e *Extractor) worker(id int) {
|
|
for job := range e.jobs {
|
|
e.process(id, job)
|
|
}
|
|
}
|
|
|
|
func (e *Extractor) process(workerID int, job ExtractionJob) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
|
defer cancel()
|
|
|
|
log.Printf("[extractor-%d] Processing response %s (deal=%s, file=%s)", workerID, job.ResponseID, job.DealID, job.FilePath)
|
|
|
|
// Set status to processing
|
|
e.db.Exec("UPDATE responses SET extraction_status = 'processing', updated_at = datetime('now') WHERE id = ?", job.ResponseID)
|
|
|
|
var body string
|
|
|
|
if job.FilePath != "" {
|
|
// Document: extract text from file
|
|
md, err := e.extractFile(ctx, job)
|
|
if err != nil {
|
|
log.Printf("[extractor-%d] Extraction failed for %s: %v", workerID, job.ResponseID, err)
|
|
e.db.Exec("UPDATE responses SET extraction_status = 'failed', updated_at = datetime('now') WHERE id = ?", job.ResponseID)
|
|
return
|
|
}
|
|
body = md
|
|
// Update response body and status
|
|
e.db.Exec("UPDATE responses SET body = ?, extraction_status = 'done', updated_at = datetime('now') WHERE id = ?", body, job.ResponseID)
|
|
} else {
|
|
// Statement: body is already set, just mark done
|
|
e.db.Exec("UPDATE responses SET extraction_status = 'done', updated_at = datetime('now') WHERE id = ?", job.ResponseID)
|
|
// Load existing body
|
|
e.db.QueryRow("SELECT body FROM responses WHERE id = ?", job.ResponseID).Scan(&body)
|
|
}
|
|
|
|
if body == "" {
|
|
log.Printf("[extractor-%d] Empty body for response %s, skipping chunk+match", workerID, job.ResponseID)
|
|
return
|
|
}
|
|
|
|
// Chunk
|
|
chunks := extract.ChunkMarkdown(body)
|
|
if len(chunks) == 0 {
|
|
log.Printf("[extractor-%d] No chunks produced for response %s", workerID, job.ResponseID)
|
|
return
|
|
}
|
|
|
|
// Embed chunks
|
|
chunkVectors, err := e.fw.EmbedText(ctx, chunks)
|
|
if err != nil {
|
|
log.Printf("[extractor-%d] Embedding failed for %s: %v", workerID, job.ResponseID, err)
|
|
e.db.Exec("UPDATE responses SET extraction_status = 'failed', updated_at = datetime('now') WHERE id = ?", job.ResponseID)
|
|
return
|
|
}
|
|
|
|
// Store chunks
|
|
chunkIDs := make([]string, len(chunks))
|
|
for i, chunk := range chunks {
|
|
chunkID := generateID("chunk")
|
|
chunkIDs[i] = chunkID
|
|
vecBytes := float32sToBytes(chunkVectors[i])
|
|
e.db.Exec("INSERT INTO response_chunks (id, response_id, chunk_index, text, vector) VALUES (?, ?, ?, ?, ?)",
|
|
chunkID, job.ResponseID, i, chunk, vecBytes)
|
|
}
|
|
|
|
// Match against open requests in this deal
|
|
linkCount := e.matchRequests(ctx, job.DealID, job.ResponseID, chunkIDs, chunkVectors)
|
|
|
|
log.Printf("[extractor-%d] Response %s: %d chunks, %d request links auto-created", workerID, job.ResponseID, len(chunks), linkCount)
|
|
}
|
|
|
|
func (e *Extractor) extractFile(ctx context.Context, job ExtractionJob) (string, error) {
|
|
if extract.IsXLSX(job.FilePath) {
|
|
// XLSX: extract text dump, send as text to LLM
|
|
text, err := extract.XLSXToText(job.FilePath)
|
|
if err != nil {
|
|
return "", fmt.Errorf("xlsx extract: %w", err)
|
|
}
|
|
md, err := e.fw.ExtractTextToMarkdown(ctx, text, job.FilePath)
|
|
if err != nil {
|
|
return "", fmt.Errorf("xlsx to markdown: %w", err)
|
|
}
|
|
return md, nil
|
|
}
|
|
|
|
// PDF or image
|
|
images, err := extract.FileToImages(job.FilePath)
|
|
if err != nil {
|
|
return "", fmt.Errorf("file to images: %w", err)
|
|
}
|
|
if len(images) == 0 {
|
|
return "", fmt.Errorf("no images extracted from file")
|
|
}
|
|
|
|
md, err := e.fw.ExtractToMarkdown(ctx, images, job.FilePath)
|
|
if err != nil {
|
|
return "", fmt.Errorf("vision extraction: %w", err)
|
|
}
|
|
return md, nil
|
|
}
|
|
|
|
func (e *Extractor) matchRequests(ctx context.Context, dealID, responseID string, chunkIDs []string, chunkVectors [][]float32) int {
|
|
// Load all requests for this deal
|
|
rows, err := e.db.Query("SELECT id, description FROM diligence_requests WHERE deal_id = ?", dealID)
|
|
if err != nil {
|
|
log.Printf("[extractor] Failed to load requests for deal %s: %v", dealID, err)
|
|
return 0
|
|
}
|
|
defer rows.Close()
|
|
|
|
type reqInfo struct {
|
|
id, desc string
|
|
}
|
|
var reqs []reqInfo
|
|
for rows.Next() {
|
|
var r reqInfo
|
|
rows.Scan(&r.id, &r.desc)
|
|
reqs = append(reqs, r)
|
|
}
|
|
|
|
if len(reqs) == 0 {
|
|
return 0
|
|
}
|
|
|
|
// Embed request descriptions
|
|
descs := make([]string, len(reqs))
|
|
for i, r := range reqs {
|
|
descs[i] = r.desc
|
|
}
|
|
reqVectors, err := e.fw.EmbedText(ctx, descs)
|
|
if err != nil {
|
|
log.Printf("[extractor] Failed to embed request descriptions: %v", err)
|
|
return 0
|
|
}
|
|
|
|
// Match each (chunk, request) pair
|
|
linkCount := 0
|
|
for ci, chunkVec := range chunkVectors {
|
|
for ri, reqVec := range reqVectors {
|
|
sim := fireworks.CosineSimilarity(chunkVec, reqVec)
|
|
if sim >= matchThreshold {
|
|
_, err := e.db.Exec(
|
|
"INSERT OR IGNORE INTO request_links (request_id, response_id, chunk_id, confidence, auto_linked, confirmed) VALUES (?, ?, ?, ?, 1, 0)",
|
|
reqs[ri].id, responseID, chunkIDs[ci], sim)
|
|
if err == nil {
|
|
linkCount++
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return linkCount
|
|
}
|
|
|
|
// float32sToBytes serialises fs as consecutive 4-byte little-endian IEEE-754
// bit patterns, in input order. The result always has length 4*len(fs) and is
// non-nil even for empty input.
func float32sToBytes(fs []float32) []byte {
	out := make([]byte, 4*len(fs))

	offset := 0
	for _, v := range fs {
		binary.LittleEndian.PutUint32(out[offset:], math.Float32bits(v))
		offset += 4
	}

	return out
}
|
|
|
|
// generateID returns prefix, a dash, and 16 lowercase hex digits encoding
// 8 bytes read from crypto/rand. The error from rand.Read is not checked.
func generateID(prefix string) string {
	var raw [8]byte
	rand.Read(raw[:])

	suffix := hex.EncodeToString(raw[:])
	return prefix + "-" + suffix
}
|