dealroom/internal/fireworks/client.go

215 lines
5.7 KiB
Go

package fireworks
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"math"
	"net/http"
	"os"
)
const (
	baseURL          = "https://api.fireworks.ai/inference/v1"
	visionModel      = "accounts/fireworks/models/llama-v3p2-90b-vision-instruct"
	embeddingModel   = "nomic-ai/nomic-embed-text-v1.5"
	maxImagesPerCall = 10 // max images sent in one vision chat-completion request
	maxTextsPerBatch = 50 // max texts sent in one /embeddings request
)

// apiKey is the Fireworks API key, read once from the FIREWORKS_API_KEY
// environment variable when the package is initialised. Credentials must not
// be committed to source control.
// NOTE(review): the key that used to be hard-coded here should be rotated.
var apiKey = os.Getenv("FIREWORKS_API_KEY")
// Client talks to the Fireworks inference HTTP API (chat completions and
// embeddings).
type Client struct {
	http *http.Client // shared transport for every request this Client makes
}

// NewClient returns a Client backed by a fresh, zero-value http.Client.
// No client-level Timeout is configured, so request deadlines come only from
// the ctx passed to each method.
func NewClient() *Client {
	hc := new(http.Client)
	return &Client{http: hc}
}
// ExtractToMarkdown sends base64-encoded images to the vision model and
// returns the extracted markdown.
//
// Images are sent in batches of at most maxImagesPerCall per request, each
// embedded as an image/jpeg data URL. The markdown returned for each batch is
// appended in order, followed by a newline. If a batch fails, the markdown
// collected from the earlier batches is returned together with the error.
//
// It returns an error if imageBase64 is empty. For text content (for example
// an XLSX dump) use ExtractTextToMarkdown instead.
func (c *Client) ExtractToMarkdown(ctx context.Context, imageBase64 []string, filename string) (string, error) {
	if len(imageBase64) == 0 {
		return "", fmt.Errorf("no images provided")
	}

	// The instruction is identical for every batch; only the images differ.
	prompt := fmt.Sprintf("Extract all content from this document (%s) into clean markdown. Preserve headings, tables, lists, and structure. Do not summarise — extract everything.", filename)

	// Accumulate into a buffer rather than string += to avoid re-copying the
	// growing document on every batch.
	var md bytes.Buffer
	for i := 0; i < len(imageBase64); i += maxImagesPerCall {
		end := i + maxImagesPerCall
		if end > len(imageBase64) {
			end = len(imageBase64)
		}
		batch := imageBase64[i:end]

		content := make([]map[string]interface{}, 0, len(batch)+1)
		content = append(content, map[string]interface{}{"type": "text", "text": prompt})
		for _, img := range batch {
			content = append(content, map[string]interface{}{
				"type": "image_url",
				"image_url": map[string]string{
					"url": "data:image/jpeg;base64," + img,
				},
			})
		}

		body := map[string]interface{}{
			"model": visionModel,
			"messages": []map[string]interface{}{
				{
					"role":    "system",
					"content": "You are a document extraction expert. Extract ALL content from this document into clean markdown. Preserve headings, tables, lists, and structure. Do not summarise — extract everything.",
				},
				{
					"role":    "user",
					"content": content,
				},
			},
			"max_tokens": 16384,
		}

		result, err := c.chatCompletion(ctx, body)
		if err != nil {
			// Return what was extracted so far alongside the failure.
			return md.String(), fmt.Errorf("vision extraction batch %d: %w", i/maxImagesPerCall, err)
		}
		md.WriteString(result)
		md.WriteByte('\n')
	}
	return md.String(), nil
}
// ExtractTextToMarkdown asks the model to convert already-textual content
// (for example an XLSX dump) into markdown. The user message is the line
// "File: <filename>", a blank line, and then textContent verbatim. The
// model's reply is returned unchanged.
func (c *Client) ExtractTextToMarkdown(ctx context.Context, textContent string, filename string) (string, error) {
	const systemPrompt = "You are a document extraction expert. Convert the following structured data into clean markdown. Preserve tables, lists, and structure."
	userPrompt := fmt.Sprintf("File: %s\n\n%s", filename, textContent)

	messages := []map[string]interface{}{
		{"role": "system", "content": systemPrompt},
		{"role": "user", "content": userPrompt},
	}
	return c.chatCompletion(ctx, map[string]interface{}{
		"model":      visionModel,
		"messages":   messages,
		"max_tokens": 16384,
	})
}
// chatCompletion POSTs body as JSON to the /chat/completions endpoint and
// returns the content of the first choice's message.
//
// A non-200 status is reported as an error that includes the raw response
// body. A response that cannot be read, cannot be decoded, or has no choices
// is also an error.
func (c *Client) chatCompletion(ctx context.Context, body map[string]interface{}) (string, error) {
	jsonBody, err := json.Marshal(body)
	if err != nil {
		return "", fmt.Errorf("encode chat request: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/chat/completions", bytes.NewReader(jsonBody))
	if err != nil {
		return "", fmt.Errorf("build chat request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+apiKey)

	resp, err := c.http.Do(req)
	if err != nil {
		return "", fmt.Errorf("send chat request: %w", err)
	}
	defer resp.Body.Close()

	// Check the read error: a truncated body would otherwise surface as a
	// misleading JSON parse error or a partial API error message.
	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", fmt.Errorf("read chat response (status %d): %w", resp.StatusCode, err)
	}
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("fireworks API error %d: %s", resp.StatusCode, string(respBody))
	}

	var result struct {
		Choices []struct {
			Message struct {
				Content string `json:"content"`
			} `json:"message"`
		} `json:"choices"`
	}
	if err := json.Unmarshal(respBody, &result); err != nil {
		return "", fmt.Errorf("parse response: %w", err)
	}
	if len(result.Choices) == 0 {
		return "", fmt.Errorf("no choices in response")
	}
	return result.Choices[0].Message.Content, nil
}
// EmbedText generates embeddings for texts using embeddingModel.
//
// Texts are sent in batches of at most maxTextsPerBatch per request. The
// result holds one embedding per input text, in batch order. An empty texts
// slice yields (nil, nil) without making any request. On error no embeddings
// are returned.
func (c *Client) EmbedText(ctx context.Context, texts []string) ([][]float32, error) {
	if len(texts) == 0 {
		return nil, nil
	}
	all := make([][]float32, 0, len(texts))
	for i := 0; i < len(texts); i += maxTextsPerBatch {
		end := i + maxTextsPerBatch
		if end > len(texts) {
			end = len(texts)
		}
		embs, err := c.embedBatch(ctx, texts[i:end])
		if err != nil {
			return nil, fmt.Errorf("embedding batch %d: %w", i/maxTextsPerBatch, err)
		}
		all = append(all, embs...)
	}
	return all, nil
}

// embedBatch performs a single /embeddings request for batch and returns the
// vectors in the order the response lists them. It is a separate function so
// the deferred Body.Close runs after each request instead of piling up until
// EmbedText returns.
// NOTE(review): the response's per-item "index" field (if the API sends one)
// is not consulted; ordering relies on the response order.
func (c *Client) embedBatch(ctx context.Context, batch []string) ([][]float32, error) {
	jsonBody, err := json.Marshal(map[string]interface{}{
		"model": embeddingModel,
		"input": batch,
	})
	if err != nil {
		return nil, fmt.Errorf("encode embedding request: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/embeddings", bytes.NewReader(jsonBody))
	if err != nil {
		return nil, fmt.Errorf("build embedding request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+apiKey)

	resp, err := c.http.Do(req)
	if err != nil {
		return nil, fmt.Errorf("send embedding request: %w", err)
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("read embedding response (status %d): %w", resp.StatusCode, err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("fireworks embedding API error %d: %s", resp.StatusCode, string(respBody))
	}

	var result struct {
		Data []struct {
			Embedding []float32 `json:"embedding"`
		} `json:"data"`
	}
	if err := json.Unmarshal(respBody, &result); err != nil {
		return nil, fmt.Errorf("parse embedding response: %w", err)
	}
	// A short or long response would silently misalign vectors with texts.
	if len(result.Data) != len(batch) {
		return nil, fmt.Errorf("embedding response has %d vectors for %d inputs", len(result.Data), len(batch))
	}

	out := make([][]float32, len(result.Data))
	for i, d := range result.Data {
		out[i] = d.Embedding
	}
	return out, nil
}
// CosineSimilarity returns the cosine of the angle between a and b.
// Accumulation is done in float64 and the result is narrowed to float32.
// It returns 0 when the slices are empty, differ in length, or either
// vector has zero magnitude.
func CosineSimilarity(a, b []float32) float32 {
	n := len(a)
	if n == 0 || n != len(b) {
		return 0
	}
	var ab, aa, bb float64
	for i, x := range a {
		xv, yv := float64(x), float64(b[i])
		ab += xv * yv
		aa += xv * xv
		bb += yv * yv
	}
	magnitude := math.Sqrt(aa) * math.Sqrt(bb)
	if magnitude == 0 {
		return 0
	}
	return float32(ab / magnitude)
}