diff --git a/internal/db/migrate.go b/internal/db/migrate.go index 1a7bef8..f1a52b3 100644 --- a/internal/db/migrate.go +++ b/internal/db/migrate.go @@ -24,6 +24,10 @@ func Migrate(db *sql.DB) error { createFolderAccess, createFileComments, createContactDeals, + createResponses, + createResponseChunks, + createRequestLinks, + createAssignmentRules, } for i, m := range migrations { @@ -242,6 +246,55 @@ CREATE TABLE IF NOT EXISTS file_comments ( created_at DATETIME DEFAULT CURRENT_TIMESTAMP );` +const createResponses = ` +CREATE TABLE IF NOT EXISTS responses ( + id TEXT PRIMARY KEY, + deal_id TEXT NOT NULL, + type TEXT NOT NULL CHECK (type IN ('document','statement')), + title TEXT NOT NULL, + body TEXT DEFAULT '', + file_id TEXT DEFAULT '', + extraction_status TEXT DEFAULT 'pending' + CHECK (extraction_status IN ('pending','processing','done','failed')), + created_by TEXT DEFAULT '', + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (deal_id) REFERENCES deals(id) +);` + +const createResponseChunks = ` +CREATE TABLE IF NOT EXISTS response_chunks ( + id TEXT PRIMARY KEY, + response_id TEXT NOT NULL, + chunk_index INTEGER NOT NULL, + text TEXT NOT NULL, + vector BLOB NOT NULL, + FOREIGN KEY (response_id) REFERENCES responses(id) +);` + +const createRequestLinks = ` +CREATE TABLE IF NOT EXISTS request_links ( + request_id TEXT NOT NULL, + response_id TEXT NOT NULL, + chunk_id TEXT NOT NULL, + confidence REAL NOT NULL, + auto_linked BOOLEAN DEFAULT 1, + confirmed BOOLEAN DEFAULT 0, + confirmed_by TEXT DEFAULT '', + confirmed_at DATETIME, + PRIMARY KEY (request_id, response_id, chunk_id) +);` + +const createAssignmentRules = ` +CREATE TABLE IF NOT EXISTS assignment_rules ( + id TEXT PRIMARY KEY, + deal_id TEXT NOT NULL, + keyword TEXT NOT NULL, + assignee_id TEXT NOT NULL, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (deal_id) REFERENCES deals(id) +);` + // Additive migrations - each statement is run individually, errors ignored (for already-existing columns) var additiveMigrationStmts = []string{ // Section 1: org_type @@ -262,6 +315,10 @@ var additiveMigrationStmts = []string{ // Section 13: analytics per-buyer `ALTER TABLE deal_activity ADD COLUMN buyer_group TEXT DEFAULT ''`, `ALTER TABLE deal_activity ADD COLUMN time_spent_seconds INTEGER DEFAULT 0`, + // Responses feature: assignee + status on requests, response_id on files + `ALTER TABLE diligence_requests ADD COLUMN assignee_id TEXT DEFAULT ''`, + `ALTER TABLE diligence_requests ADD COLUMN status TEXT DEFAULT 'open'`, + `ALTER TABLE files ADD COLUMN response_id TEXT DEFAULT ''`, } // fixDealStageConstraint recreates the deals table if it was created with the diff --git a/internal/extract/chunker.go b/internal/extract/chunker.go new file mode 100644 index 0000000..64ccc10 --- /dev/null +++ b/internal/extract/chunker.go @@ -0,0 +1,140 @@ +package extract + +import ( + "strings" +) + +const ( + maxChunkChars = 1600 + overlapChars = 80 + minChunkChars = 50 +) + +// ChunkMarkdown splits markdown text into overlapping chunks for embedding. +func ChunkMarkdown(text string) []string { + if strings.TrimSpace(text) == "" { + return nil + } + + // Split on headings first + sections := splitOnHeadings(text) + + var chunks []string + for _, section := range sections { + section = strings.TrimSpace(section) + if len(section) < minChunkChars { + continue + } + + if len(section) <= maxChunkChars { + chunks = append(chunks, section) + continue + } + + // Split further at paragraph breaks + paragraphs := strings.Split(section, "\n\n") + var current strings.Builder + for _, para := range paragraphs { + para = strings.TrimSpace(para) + if para == "" { + continue + } + + if current.Len()+len(para)+2 > maxChunkChars && current.Len() > 0 { + chunks = append(chunks, current.String()) + current.Reset() + } + + if len(para) > maxChunkChars { + // Flush current buffer first + if current.Len() > 0 { + chunks = append(chunks, current.String()) + current.Reset() + } + // Split at sentence boundaries + sentences := splitSentences(para) + for _, sent := range sentences { + if current.Len()+len(sent)+1 > maxChunkChars && current.Len() > 0 { + chunks = append(chunks, current.String()) + current.Reset() + } + if current.Len() > 0 { + current.WriteString(" ") + } + current.WriteString(sent) + } + } else { + if current.Len() > 0 { + current.WriteString("\n\n") + } + current.WriteString(para) + } + } + if current.Len() > 0 { + chunks = append(chunks, current.String()) + } + } + + // Apply overlap + if len(chunks) > 1 { + overlapped := make([]string, len(chunks)) + overlapped[0] = chunks[0] + for i := 1; i < len(chunks); i++ { + prev := chunks[i-1] + overlap := prev + if len(overlap) > overlapChars { + overlap = overlap[len(overlap)-overlapChars:] + } + overlapped[i] = overlap + " " + chunks[i] + } + chunks = overlapped + } + + // Filter out too-short chunks + var result []string + for _, c := range chunks { + if len(strings.TrimSpace(c)) >= minChunkChars { + result = append(result, c) + } + } + + return result +} + +func splitOnHeadings(text string) []string { + lines := strings.Split(text, "\n") + var sections []string + var current strings.Builder + + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if (strings.HasPrefix(trimmed, "## ") || strings.HasPrefix(trimmed, "### ")) && current.Len() > 0 { + sections = append(sections, current.String()) + current.Reset() + } + current.WriteString(line) + current.WriteString("\n") + } + if current.Len() > 0 { + sections = append(sections, current.String()) + } + + return sections +} + +func splitSentences(text string) []string { + // Split on ". " while preserving the period + parts := strings.Split(text, ". ") + var sentences []string + for i, p := range parts { + p = strings.TrimSpace(p) + if p == "" { + continue + } + if i < len(parts)-1 { + p += "." + } + sentences = append(sentences, p) + } + return sentences +} diff --git a/internal/extract/pdf.go b/internal/extract/pdf.go new file mode 100644 index 0000000..004ef76 --- /dev/null +++ b/internal/extract/pdf.go @@ -0,0 +1,134 @@ +package extract + +import ( + "bytes" + "encoding/base64" + "fmt" + "os" + "os/exec" + "path/filepath" + "sort" + "strings" + + "github.com/xuri/excelize/v2" +) + +// FileToImages converts a file to base64-encoded JPEG images for vision extraction. +// For images (jpg/png), returns the base64 directly. +// For XLSX, returns nil (caller should use XLSXToText instead). +// For PDF, uses pdftoppm to rasterise pages. +func FileToImages(path string) ([]string, error) { + ext := strings.ToLower(filepath.Ext(path)) + + switch ext { + case ".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp": + return imageToBase64(path) + case ".xlsx", ".xls": + return nil, nil // caller uses XLSXToText + case ".pdf": + return pdfToImages(path) + default: + // Try pdftoppm anyway; if it fails, return empty + imgs, err := pdfToImages(path) + if err != nil || len(imgs) == 0 { + return nil, fmt.Errorf("unsupported file type %s", ext) + } + return imgs, nil + } +} + +// XLSXToText extracts all sheets from an XLSX file as markdown tables. +func XLSXToText(path string) (string, error) { + f, err := excelize.OpenFile(path) + if err != nil { + return "", fmt.Errorf("open xlsx: %w", err) + } + defer f.Close() + + var buf bytes.Buffer + for _, sheetName := range f.GetSheetList() { + rows, err := f.GetRows(sheetName) + if err != nil { + continue + } + if len(rows) == 0 { + continue + } + + buf.WriteString(fmt.Sprintf("## %s\n\n", sheetName)) + + // Write as markdown table + if len(rows) > 0 { + // Header row + buf.WriteString("| " + strings.Join(rows[0], " | ") + " |\n") + buf.WriteString("|" + strings.Repeat(" --- |", len(rows[0])) + "\n") + // Data rows + for _, row := range rows[1:] { + // Pad row if shorter than header + for len(row) < len(rows[0]) { + row = append(row, "") + } + buf.WriteString("| " + strings.Join(row, " | ") + " |\n") + } + } + buf.WriteString("\n") + } + + return buf.String(), nil +} + +func imageToBase64(path string) ([]string, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + return []string{base64.StdEncoding.EncodeToString(data)}, nil +} + +func pdfToImages(path string) ([]string, error) { + tmpDir, err := os.MkdirTemp("", "pdf2img-") + if err != nil { + return nil, err + } + defer os.RemoveAll(tmpDir) + + prefix := filepath.Join(tmpDir, "page") + cmd := exec.Command("pdftoppm", "-jpeg", "-r", "150", path, prefix) + if out, err := cmd.CombinedOutput(); err != nil { + return nil, fmt.Errorf("pdftoppm failed: %w: %s", err, string(out)) + } + + // Read generated files in sorted order + entries, err := os.ReadDir(tmpDir) + if err != nil { + return nil, err + } + + var names []string + for _, e := range entries { + if strings.HasSuffix(e.Name(), ".jpg") { + names = append(names, e.Name()) + } + } + sort.Strings(names) + + var images []string + for _, name := range names { + data, err := os.ReadFile(filepath.Join(tmpDir, name)) + if err != nil { + continue + } + images = append(images, base64.StdEncoding.EncodeToString(data)) + } + + if len(images) == 0 { + return nil, fmt.Errorf("pdftoppm produced no images") + } + return images, nil +} + +// IsXLSX returns true if the file is an Excel file. +func IsXLSX(path string) bool { + ext := strings.ToLower(filepath.Ext(path)) + return ext == ".xlsx" || ext == ".xls" +} diff --git a/internal/fireworks/client.go b/internal/fireworks/client.go new file mode 100644 index 0000000..e5b9644 --- /dev/null +++ b/internal/fireworks/client.go @@ -0,0 +1,214 @@ +package fireworks + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "math" + "net/http" +) + +const ( + baseURL = "https://api.fireworks.ai/inference/v1" + apiKey = "fw_RVcDe4c6mN4utKLsgA7hTm" + visionModel = "accounts/fireworks/models/llama-v3p2-90b-vision-instruct" + embeddingModel = "nomic-ai/nomic-embed-text-v1.5" + maxImagesPerCall = 10 + maxTextsPerBatch = 50 +) + +type Client struct { + http *http.Client +} + +func NewClient() *Client { + return &Client{http: &http.Client{}} +} + +// ExtractToMarkdown sends base64 images to the vision model and returns extracted markdown. +// For XLSX text content, pass nil images and set textContent instead. +func (c *Client) ExtractToMarkdown(ctx context.Context, imageBase64 []string, filename string) (string, error) { + if len(imageBase64) == 0 { + return "", fmt.Errorf("no images provided") + } + + var fullMarkdown string + + // Batch images into groups of maxImagesPerCall + for i := 0; i < len(imageBase64); i += maxImagesPerCall { + end := i + maxImagesPerCall + if end > len(imageBase64) { + end = len(imageBase64) + } + batch := imageBase64[i:end] + + content := []map[string]interface{}{ + {"type": "text", "text": fmt.Sprintf("Extract all content from this document (%s) into clean markdown. Preserve headings, tables, lists, and structure. Do not summarise — extract everything.", filename)}, + } + for _, img := range batch { + content = append(content, map[string]interface{}{ + "type": "image_url", + "image_url": map[string]string{ + "url": "data:image/jpeg;base64," + img, + }, + }) + } + + body := map[string]interface{}{ + "model": visionModel, + "messages": []map[string]interface{}{ + { + "role": "system", + "content": "You are a document extraction expert. Extract ALL content from this document into clean markdown. Preserve headings, tables, lists, and structure. Do not summarise — extract everything.", + }, + { + "role": "user", + "content": content, + }, + }, + "max_tokens": 16384, + } + + result, err := c.chatCompletion(ctx, body) + if err != nil { + return fullMarkdown, fmt.Errorf("vision extraction batch %d: %w", i/maxImagesPerCall, err) + } + fullMarkdown += result + "\n" + } + + return fullMarkdown, nil +} + +// ExtractTextToMarkdown sends structured text (e.g. XLSX dump) to the model for markdown conversion. +func (c *Client) ExtractTextToMarkdown(ctx context.Context, textContent string, filename string) (string, error) { + body := map[string]interface{}{ + "model": visionModel, + "messages": []map[string]interface{}{ + { + "role": "system", + "content": "You are a document extraction expert. Convert the following structured data into clean markdown. Preserve tables, lists, and structure.", + }, + { + "role": "user", + "content": fmt.Sprintf("File: %s\n\n%s", filename, textContent), + }, + }, + "max_tokens": 16384, + } + return c.chatCompletion(ctx, body) +} + +func (c *Client) chatCompletion(ctx context.Context, body map[string]interface{}) (string, error) { + jsonBody, err := json.Marshal(body) + if err != nil { + return "", err + } + + req, err := http.NewRequestWithContext(ctx, "POST", baseURL+"/chat/completions", bytes.NewReader(jsonBody)) + if err != nil { + return "", err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + + resp, err := c.http.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + if resp.StatusCode != 200 { + return "", fmt.Errorf("fireworks API error %d: %s", resp.StatusCode, string(respBody)) + } + + var result struct { + Choices []struct { + Message struct { + Content string `json:"content"` + } `json:"message"` + } `json:"choices"` + } + if err := json.Unmarshal(respBody, &result); err != nil { + return "", fmt.Errorf("parse response: %w", err) + } + if len(result.Choices) == 0 { + return "", fmt.Errorf("no choices in response") + } + return result.Choices[0].Message.Content, nil +} + +// EmbedText generates embeddings for a batch of texts. +func (c *Client) EmbedText(ctx context.Context, texts []string) ([][]float32, error) { + var allEmbeddings [][]float32 + + for i := 0; i < len(texts); i += maxTextsPerBatch { + end := i + maxTextsPerBatch + if end > len(texts) { + end = len(texts) + } + batch := texts[i:end] + + body := map[string]interface{}{ + "model": embeddingModel, + "input": batch, + } + jsonBody, err := json.Marshal(body) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, "POST", baseURL+"/embeddings", bytes.NewReader(jsonBody)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + + resp, err := c.http.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + if resp.StatusCode != 200 { + return nil, fmt.Errorf("fireworks embedding API error %d: %s", resp.StatusCode, string(respBody)) + } + + var result struct { + Data []struct { + Embedding []float32 `json:"embedding"` + } `json:"data"` + } + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, fmt.Errorf("parse embedding response: %w", err) + } + + for _, d := range result.Data { + allEmbeddings = append(allEmbeddings, d.Embedding) + } + } + + return allEmbeddings, nil +} + +// CosineSimilarity computes the cosine similarity between two vectors. +func CosineSimilarity(a, b []float32) float32 { + if len(a) != len(b) || len(a) == 0 { + return 0 + } + var dot, normA, normB float64 + for i := range a { + dot += float64(a[i]) * float64(b[i]) + normA += float64(a[i]) * float64(a[i]) + normB += float64(b[i]) * float64(b[i]) + } + denom := math.Sqrt(normA) * math.Sqrt(normB) + if denom == 0 { + return 0 + } + return float32(dot / denom) +} diff --git a/internal/handler/deals.go b/internal/handler/deals.go index 8ad965f..3cb4731 100644 --- a/internal/handler/deals.go +++ b/internal/handler/deals.go @@ -482,7 +482,7 @@ func (h *Handler) getFolders(dealID string) []*model.Folder { } func (h *Handler) getFiles(dealID string) []*model.File { - rows, err := h.db.Query("SELECT id, deal_id, folder_id, name, file_size, mime_type, status, uploaded_by, created_at FROM files WHERE deal_id = ? ORDER BY name", dealID) + rows, err := h.db.Query("SELECT id, deal_id, folder_id, name, file_size, mime_type, status, uploaded_by, created_at, COALESCE(response_id, '') FROM files WHERE deal_id = ? ORDER BY name", dealID) if err != nil { return nil } @@ -491,14 +491,18 @@ func (h *Handler) getFiles(dealID string) []*model.File { var files []*model.File for rows.Next() { f := &model.File{} - rows.Scan(&f.ID, &f.DealID, &f.FolderID, &f.Name, &f.FileSize, &f.MimeType, &f.Status, &f.UploadedBy, &f.CreatedAt) + rows.Scan(&f.ID, &f.DealID, &f.FolderID, &f.Name, &f.FileSize, &f.MimeType, &f.Status, &f.UploadedBy, &f.CreatedAt, &f.ResponseID) + // Load extraction status from responses table + if f.ResponseID != "" { + h.db.QueryRow("SELECT extraction_status FROM responses WHERE id = ?", f.ResponseID).Scan(&f.ExtractionStatus) + } files = append(files, f) } return files } func (h *Handler) getRequests(dealID string, profile *model.Profile) []*model.DiligenceRequest { - query := "SELECT id, deal_id, item_number, section, description, priority, atlas_status, atlas_note, confidence, buyer_comment, seller_comment, buyer_group, linked_file_ids, COALESCE(is_buyer_specific, 0), COALESCE(visible_to_buyer_group, '') FROM diligence_requests WHERE deal_id = ?" + query := "SELECT id, deal_id, item_number, section, description, priority, atlas_status, atlas_note, confidence, buyer_comment, seller_comment, buyer_group, linked_file_ids, COALESCE(is_buyer_specific, 0), COALESCE(visible_to_buyer_group, ''), COALESCE(assignee_id, ''), COALESCE(status, 'open') FROM diligence_requests WHERE deal_id = ?" args := []interface{}{dealID} if rbac.EffectiveIsBuyer(profile) { @@ -529,9 +533,19 @@ func (h *Handler) getRequests(dealID string, profile *model.Profile) []*model.Di var reqs []*model.DiligenceRequest for rows.Next() { r := &model.DiligenceRequest{} - rows.Scan(&r.ID, &r.DealID, &r.ItemNumber, &r.Section, &r.Description, &r.Priority, &r.AtlasStatus, &r.AtlasNote, &r.Confidence, &r.BuyerComment, &r.SellerComment, &r.BuyerGroup, &r.LinkedFileIDs, &r.IsBuyerSpecific, &r.VisibleToBuyerGroup) + rows.Scan(&r.ID, &r.DealID, &r.ItemNumber, &r.Section, &r.Description, &r.Priority, &r.AtlasStatus, &r.AtlasNote, &r.Confidence, &r.BuyerComment, &r.SellerComment, &r.BuyerGroup, &r.LinkedFileIDs, &r.IsBuyerSpecific, &r.VisibleToBuyerGroup, &r.AssigneeID, &r.Status) reqs = append(reqs, r) } + + // Enrich with assignee names and link counts + for _, r := range reqs { + if r.AssigneeID != "" { + h.db.QueryRow("SELECT full_name FROM profiles WHERE id = ?", r.AssigneeID).Scan(&r.AssigneeName) + } + h.db.QueryRow("SELECT COUNT(*) FROM request_links WHERE request_id = ? AND confirmed = 0 AND auto_linked = 1", r.ID).Scan(&r.PendingMatches) + h.db.QueryRow("SELECT COUNT(*) FROM request_links WHERE request_id = ? AND confirmed = 1", r.ID).Scan(&r.ConfirmedLinks) + } + return reqs } diff --git a/internal/handler/files.go b/internal/handler/files.go index 93a4434..0d4b633 100644 --- a/internal/handler/files.go +++ b/internal/handler/files.go @@ -86,6 +86,14 @@ func (h *Handler) handleFileUpload(w http.ResponseWriter, r *http.Request) { h.db.Exec("UPDATE diligence_requests SET linked_file_ids = ? WHERE id = ?", existing, requestItemID) } + // Create a response record for this document and enqueue extraction + respID := generateID("resp") + h.db.Exec( + `INSERT INTO responses (id, deal_id, type, title, file_id, extraction_status, created_by) VALUES (?, ?, 'document', ?, ?, 'pending', ?)`, + respID, dealID, header.Filename, fileID, profile.ID) + h.db.Exec("UPDATE files SET response_id = ? WHERE id = ?", respID, fileID) + h.enqueueExtraction(respID, storagePath, dealID) + // Log activity h.logActivity(dealID, profile.ID, profile.OrganizationID, "upload", "file", header.Filename, fileID) diff --git a/internal/handler/handler.go b/internal/handler/handler.go index b738633..50dc6aa 100644 --- a/internal/handler/handler.go +++ b/internal/handler/handler.go @@ -4,12 +4,16 @@ import ( "database/sql" "net/http" + "dealroom/internal/fireworks" "dealroom/internal/model" + "dealroom/internal/worker" ) type Handler struct { - db *sql.DB - config *Config + db *sql.DB + config *Config + extractor *worker.Extractor + fw *fireworks.Client } type Config struct { @@ -23,9 +27,25 @@ type Config struct { } func New(db *sql.DB, _ interface{}, config *Config) *Handler { + fw := fireworks.NewClient() + ext := worker.NewExtractor(db, fw) + ext.Start() return &Handler{ - db: db, - config: config, + db: db, + config: config, + extractor: ext, + fw: fw, + } +} + +// enqueueExtraction submits a job to the background extraction worker. +func (h *Handler) enqueueExtraction(responseID, filePath, dealID string) { + if h.extractor != nil { + h.extractor.Enqueue(worker.ExtractionJob{ + ResponseID: responseID, + FilePath: filePath, + DealID: dealID, + }) } } @@ -104,6 +124,14 @@ mux.HandleFunc("/auth/logout", h.handleLogout) // HTMX partials mux.HandleFunc("/htmx/request-comment", h.requireAuth(h.handleUpdateComment)) + + // Responses & AI matching + mux.HandleFunc("/deals/responses/statement", h.requireAuth(h.handleCreateStatement)) + mux.HandleFunc("/deals/responses/confirm", h.requireAuth(h.handleConfirmLink)) + mux.HandleFunc("/deals/responses/reject", h.requireAuth(h.handleRejectLink)) + mux.HandleFunc("/deals/responses/pending/", h.requireAuth(h.handlePendingLinks)) + mux.HandleFunc("/deals/assignment-rules/save", h.requireAuth(h.handleSaveAssignmentRules)) + mux.HandleFunc("/deals/assignment-rules/", h.requireAuth(h.handleGetAssignmentRules)) } // Middleware diff --git a/internal/handler/requests.go b/internal/handler/requests.go index d2e7b85..aa4dbcb 100644 --- a/internal/handler/requests.go +++ b/internal/handler/requests.go @@ -338,6 +338,9 @@ func (h *Handler) handleRequestListUpload(w http.ResponseWriter, r *http.Request // Auto-assign existing files to matching requests h.autoAssignFilesToRequests(dealID) + // Auto-assign by keyword rules + h.autoAssignByRules(dealID) + h.logActivity(dealID, profile.ID, profile.OrganizationID, "upload", "request_list", fmt.Sprintf("%d items", len(items)), "") http.Redirect(w, r, "/deals/"+dealID+"?tab=requests", http.StatusSeeOther) diff --git a/internal/handler/responses.go b/internal/handler/responses.go new file mode 100644 index 0000000..ba2bc05 --- /dev/null +++ b/internal/handler/responses.go @@ -0,0 +1,274 @@ +package handler + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" + "time" +) + +func (h *Handler) handleCreateStatement(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + profile := getProfile(r.Context()) + + dealID := r.FormValue("deal_id") + title := strings.TrimSpace(r.FormValue("title")) + body := strings.TrimSpace(r.FormValue("body")) + + if dealID == "" || title == "" || body == "" { + http.Error(w, "deal_id, title, and body are required", 400) + return + } + + respID := generateID("resp") + _, err := h.db.Exec( + `INSERT INTO responses (id, deal_id, type, title, body, extraction_status, created_by) VALUES (?, ?, 'statement', ?, ?, 'pending', ?)`, + respID, dealID, title, body, profile.ID) + if err != nil { + http.Error(w, fmt.Sprintf("Error creating statement: %v", err), 500) + return + } + + // Enqueue for chunking + embedding + matching + if h.extractor != nil { + h.enqueueExtraction(respID, "", dealID) + } + + http.Redirect(w, r, "/deals/"+dealID+"?tab=requests", http.StatusSeeOther) +} + +func (h *Handler) handleConfirmLink(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + profile := getProfile(r.Context()) + + requestID := r.FormValue("request_id") + responseID := r.FormValue("response_id") + chunkID := r.FormValue("chunk_id") + + if requestID == "" || responseID == "" || chunkID == "" { + http.Error(w, "Missing fields", 400) + return + } + + _, err := h.db.Exec( + "UPDATE request_links SET confirmed = 1, confirmed_by = ?, confirmed_at = ? WHERE request_id = ? AND response_id = ? AND chunk_id = ?", + profile.ID, time.Now().UTC().Format("2006-01-02 15:04:05"), requestID, responseID, chunkID) + if err != nil { + http.Error(w, "Error confirming link", 500) + return + } + + // Update request status to answered if not already + h.db.Exec("UPDATE diligence_requests SET status = 'answered' WHERE id = ? AND status != 'answered'", requestID) + + w.Header().Set("Content-Type", "text/html") + w.Write([]byte(`Confirmed`)) +} + +func (h *Handler) handleRejectLink(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + requestID := r.FormValue("request_id") + responseID := r.FormValue("response_id") + chunkID := r.FormValue("chunk_id") + + if requestID == "" || responseID == "" || chunkID == "" { + http.Error(w, "Missing fields", 400) + return + } + + h.db.Exec("DELETE FROM request_links WHERE request_id = ? AND response_id = ? AND chunk_id = ?", + requestID, responseID, chunkID) + + w.Header().Set("Content-Type", "text/html") + w.Write([]byte(`Rejected`)) +} + +func (h *Handler) handlePendingLinks(w http.ResponseWriter, r *http.Request) { + dealID := strings.TrimPrefix(r.URL.Path, "/deals/responses/pending/") + if dealID == "" { + http.Error(w, "Missing deal ID", 400) + return + } + + rows, err := h.db.Query(` + SELECT rl.request_id, rl.response_id, rl.chunk_id, rl.confidence, + dr.description, r.title, r.type + FROM request_links rl + JOIN diligence_requests dr ON rl.request_id = dr.id + JOIN responses r ON rl.response_id = r.id + WHERE dr.deal_id = ? AND rl.confirmed = 0 AND rl.auto_linked = 1 + ORDER BY rl.confidence DESC + `, dealID) + if err != nil { + http.Error(w, "Error loading pending links", 500) + return + } + defer rows.Close() + + type pendingLink struct { + RequestID string `json:"request_id"` + ResponseID string `json:"response_id"` + ChunkID string `json:"chunk_id"` + Confidence float64 `json:"confidence"` + RequestDesc string `json:"request_desc"` + ResponseTitle string `json:"response_title"` + ResponseType string `json:"response_type"` + } + + var links []pendingLink + for rows.Next() { + var l pendingLink + rows.Scan(&l.RequestID, &l.ResponseID, &l.ChunkID, &l.Confidence, + &l.RequestDesc, &l.ResponseTitle, &l.ResponseType) + links = append(links, l) + } + + if links == nil { + links = []pendingLink{} + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(links) +} + +func (h *Handler) handleSaveAssignmentRules(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + dealID := r.FormValue("deal_id") + rulesJSON := r.FormValue("rules") + + if dealID == "" { + http.Error(w, "Missing deal_id", 400) + return + } + + type ruleInput struct { + Keyword string `json:"keyword"` + AssigneeID string `json:"assignee_id"` + } + var rules []ruleInput + if err := json.Unmarshal([]byte(rulesJSON), &rules); err != nil { + http.Error(w, "Invalid rules JSON", 400) + return + } + + // Delete existing rules and insert new set + h.db.Exec("DELETE FROM assignment_rules WHERE deal_id = ?", dealID) + for _, rule := range rules { + if rule.Keyword == "" || rule.AssigneeID == "" { + continue + } + id := generateID("rule") + h.db.Exec("INSERT INTO assignment_rules (id, deal_id, keyword, assignee_id) VALUES (?, ?, ?, ?)", + id, dealID, rule.Keyword, rule.AssigneeID) + } + + // Re-run auto-assignment + h.autoAssignByRules(dealID) + + http.Redirect(w, r, "/deals/"+dealID+"?tab=requests", http.StatusSeeOther) +} + +func (h *Handler) handleGetAssignmentRules(w http.ResponseWriter, r *http.Request) { + dealID := strings.TrimPrefix(r.URL.Path, "/deals/assignment-rules/") + if dealID == "" { + http.Error(w, "Missing deal ID", 400) + return + } + + rows, err := h.db.Query(` + SELECT ar.id, ar.keyword, ar.assignee_id, COALESCE(p.full_name, ar.assignee_id) + FROM assignment_rules ar + LEFT JOIN profiles p ON ar.assignee_id = p.id + WHERE ar.deal_id = ? + ORDER BY ar.keyword + `, dealID) + if err != nil { + http.Error(w, "Error loading rules", 500) + return + } + defer rows.Close() + + type ruleOut struct { + ID string `json:"id"` + Keyword string `json:"keyword"` + AssigneeID string `json:"assignee_id"` + AssigneeName string `json:"assignee_name"` + } + var rules []ruleOut + for rows.Next() { + var r ruleOut + rows.Scan(&r.ID, &r.Keyword, &r.AssigneeID, &r.AssigneeName) + rules = append(rules, r) + } + if rules == nil { + rules = []ruleOut{} + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(rules) +} + +// autoAssignByRules assigns unassigned requests based on keyword→assignee rules. +func (h *Handler) autoAssignByRules(dealID string) { + // Load rules + ruleRows, err := h.db.Query("SELECT keyword, assignee_id FROM assignment_rules WHERE deal_id = ?", dealID) + if err != nil { + return + } + defer ruleRows.Close() + + type rule struct { + keyword, assigneeID string + } + var rules []rule + for ruleRows.Next() { + var r rule + ruleRows.Scan(&r.keyword, &r.assigneeID) + rules = append(rules, r) + } + + if len(rules) == 0 { + return + } + + // Load unassigned requests + reqRows, err := h.db.Query("SELECT id, section, description FROM diligence_requests WHERE deal_id = ? AND (assignee_id = '' OR assignee_id IS NULL)", dealID) + if err != nil { + return + } + defer reqRows.Close() + + type reqInfo struct { + id, section, desc string + } + var reqs []reqInfo + for reqRows.Next() { + var r reqInfo + reqRows.Scan(&r.id, &r.section, &r.desc) + reqs = append(reqs, r) + } + + for _, req := range reqs { + text := strings.ToLower(req.section + " " + req.desc) + for _, rule := range rules { + if strings.Contains(text, strings.ToLower(rule.keyword)) { + h.db.Exec("UPDATE diligence_requests SET assignee_id = ? WHERE id = ?", rule.assigneeID, req.id) + break + } + } + } +} diff --git a/internal/model/models.go b/internal/model/models.go index 6ca5e9a..a2c15b4 100644 --- a/internal/model/models.go +++ b/internal/model/models.go @@ -73,16 +73,18 @@ type Folder struct { } type File struct { - ID string - DealID string - FolderID string - Name string - FileSize int64 - MimeType string - Status string // uploaded, processing, reviewed, flagged, archived - StoragePath string - UploadedBy string - CreatedAt time.Time + ID string + DealID string + FolderID string + Name string + FileSize int64 + MimeType string + Status string // uploaded, processing, reviewed, flagged, archived + StoragePath string + ResponseID string + UploadedBy string + CreatedAt time.Time + ExtractionStatus string // computed from responses table } type DiligenceRequest struct { @@ -101,9 +103,15 @@ type DiligenceRequest struct { LinkedFileIDs string IsBuyerSpecific bool VisibleToBuyerGroup string + AssigneeID string + Status string // open, in_progress, answered, not_applicable CreatedBy string CreatedAt time.Time UpdatedAt time.Time + // Computed + AssigneeName string + PendingMatches int + ConfirmedLinks int } type Contact struct { diff --git a/internal/worker/extractor.go b/internal/worker/extractor.go new file mode 100644 index 0000000..1e267cc --- /dev/null +++ b/internal/worker/extractor.go @@ -0,0 +1,217 @@ +package worker + +import ( + "context" + "crypto/rand" + "database/sql" + "encoding/binary" + "encoding/hex" + "fmt" + "log" + "math" + "time" + + "dealroom/internal/extract" + "dealroom/internal/fireworks" +) + +const matchThreshold = 0.72 + +type ExtractionJob struct { + ResponseID string + FilePath string // absolute path to uploaded file (or "" for statements) + DealID string +} + +type Extractor struct { + db *sql.DB + fw *fireworks.Client + jobs chan ExtractionJob +} + +func NewExtractor(db *sql.DB, fw *fireworks.Client) *Extractor { + return &Extractor{ + db: db, + fw: fw, + jobs: make(chan ExtractionJob, 100), + } +} + +func (e *Extractor) Start() { + for i := 0; i < 2; i++ { + go e.worker(i) + } + log.Println("Extraction worker started (2 goroutines)") +} + +func (e *Extractor) Enqueue(job ExtractionJob) { + e.jobs <- job +} + +func (e *Extractor) worker(id int) { + for job := range e.jobs { + e.process(id, job) + } +} + +func (e *Extractor) process(workerID int, job ExtractionJob) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + defer cancel() + + log.Printf("[extractor-%d] Processing response %s (deal=%s, file=%s)", workerID, job.ResponseID, job.DealID, job.FilePath) + + // Set status to processing + e.db.Exec("UPDATE responses SET extraction_status = 'processing', updated_at = datetime('now') WHERE id = ?", job.ResponseID) + + var body string + + if job.FilePath != "" { + // Document: extract text from file + md, err := e.extractFile(ctx, job) + if err != nil { + log.Printf("[extractor-%d] Extraction failed for %s: %v", workerID, job.ResponseID, err) + e.db.Exec("UPDATE responses SET extraction_status = 'failed', updated_at = datetime('now') WHERE id = ?", job.ResponseID) + return + } + body = md + // Update response body and status + e.db.Exec("UPDATE responses SET body = ?, extraction_status = 'done', updated_at = datetime('now') WHERE id = ?", body, job.ResponseID) + } else { + // Statement: body is already set, just mark done + e.db.Exec("UPDATE responses SET extraction_status = 'done', updated_at = datetime('now') WHERE id = ?", job.ResponseID) + // Load existing body + e.db.QueryRow("SELECT body FROM responses WHERE id = ?", job.ResponseID).Scan(&body) + } + + if body == "" { + log.Printf("[extractor-%d] Empty body for response %s, skipping chunk+match", workerID, job.ResponseID) + return + } + + // Chunk + chunks := extract.ChunkMarkdown(body) + if len(chunks) == 0 { + log.Printf("[extractor-%d] No chunks produced for response %s", workerID, job.ResponseID) + return + } + + // Embed chunks + chunkVectors, err := e.fw.EmbedText(ctx, chunks) + if err != nil { + log.Printf("[extractor-%d] Embedding failed for %s: %v", workerID, job.ResponseID, err) + e.db.Exec("UPDATE responses SET extraction_status = 'failed', updated_at = datetime('now') WHERE id = ?", job.ResponseID) + return + } + + // Store chunks + chunkIDs := make([]string, len(chunks)) + for i, chunk := range chunks { + chunkID := generateID("chunk") + chunkIDs[i] = chunkID + vecBytes := float32sToBytes(chunkVectors[i]) + e.db.Exec("INSERT INTO response_chunks (id, response_id, chunk_index, text, vector) VALUES (?, ?, ?, ?, ?)", + chunkID, job.ResponseID, i, chunk, vecBytes) + } + + // Match against open requests in this deal + linkCount := e.matchRequests(ctx, job.DealID, job.ResponseID, chunkIDs, chunkVectors) + + log.Printf("[extractor-%d] Response %s: %d chunks, %d request links auto-created", workerID, job.ResponseID, len(chunks), linkCount) +} + +func (e *Extractor) extractFile(ctx context.Context, job ExtractionJob) (string, error) { + if extract.IsXLSX(job.FilePath) { + // XLSX: extract text dump, send as text to LLM + text, err := extract.XLSXToText(job.FilePath) + if err != nil { + return "", fmt.Errorf("xlsx extract: %w", err) + } + md, err := e.fw.ExtractTextToMarkdown(ctx, text, job.FilePath) + if err != nil { + return "", fmt.Errorf("xlsx to markdown: %w", err) + } + return md, nil + } + + // PDF or image + images, err := extract.FileToImages(job.FilePath) + if err != nil { + return "", fmt.Errorf("file to images: %w", err) + } + if len(images) == 0 { + return "", fmt.Errorf("no images extracted from file") + } + + md, err := e.fw.ExtractToMarkdown(ctx, images, job.FilePath) + if err != nil { + return "", fmt.Errorf("vision extraction: %w", err) + } + return md, nil +} + +func (e *Extractor) matchRequests(ctx context.Context, dealID, responseID string, chunkIDs []string, chunkVectors [][]float32) int { + // Load all requests for this deal + rows, err := e.db.Query("SELECT id, description FROM diligence_requests WHERE deal_id = ?", dealID) + if err != nil { + log.Printf("[extractor] Failed to load requests for deal %s: %v", dealID, err) + return 0 + } + defer rows.Close() + + type reqInfo struct { + id, desc string + } + var reqs []reqInfo + for rows.Next() { + var r reqInfo + rows.Scan(&r.id, &r.desc) + reqs = append(reqs, r) + } + + if len(reqs) == 0 { + return 0 + } + + // Embed request descriptions + descs := make([]string, len(reqs)) + for i, r := range reqs { + descs[i] = r.desc + } + reqVectors, err := e.fw.EmbedText(ctx, descs) + if err != nil { + log.Printf("[extractor] Failed to embed request descriptions: %v", err) + return 0 + } + + // Match each (chunk, request) pair + linkCount := 0 + for ci, chunkVec := range chunkVectors { + for ri, reqVec := range reqVectors { + sim := fireworks.CosineSimilarity(chunkVec, reqVec) + if sim >= matchThreshold { + _, err := e.db.Exec( + "INSERT OR IGNORE INTO request_links (request_id, response_id, chunk_id, confidence, auto_linked, confirmed) VALUES (?, ?, ?, ?, 1, 0)", + reqs[ri].id, responseID, chunkIDs[ci], sim) + if err == nil { + linkCount++ + } + } + } + } + + return linkCount +} + +func float32sToBytes(fs []float32) []byte { + buf := make([]byte, len(fs)*4) + for i, f := range fs { + binary.LittleEndian.PutUint32(buf[i*4:], math.Float32bits(f)) + } + return buf +} + +func generateID(prefix string) string { + b := make([]byte, 8) + rand.Read(b) + return prefix + "-" + hex.EncodeToString(b) +} diff --git a/templates/dealroom.templ b/templates/dealroom.templ index cfece0e..3ef41ff 100644 --- a/templates/dealroom.templ +++ b/templates/dealroom.templ @@ -160,6 +160,7 @@ templ DealRoomDetail(profile *model.Profile, deal *model.Deal, folders []*model.