215 lines
5.7 KiB
Go
215 lines
5.7 KiB
Go
package fireworks
|
|
|
|
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"math"
	"net/http"
	"os"
)
|
|
|
|
// Endpoint, credential source, model identifiers and batching limits used by
// the Fireworks client.
const (
	// baseURL is the prefix every request path is appended to.
	baseURL = "https://api.fireworks.ai/inference/v1"

	// apiKeyEnvVar names the environment variable consulted for the API key.
	apiKeyEnvVar = "FIREWORKS_API_KEY"

	// fallbackAPIKey is used only when apiKeyEnvVar is unset or empty, so
	// existing deployments keep working.
	//
	// NOTE(review): this credential is committed to source control. It should
	// be rotated and this fallback removed once every deployment sets
	// FIREWORKS_API_KEY.
	fallbackAPIKey = "fw_RVcDe4c6mN4utKLsgA7hTm"

	// visionModel is the model used for chat completion requests.
	visionModel = "accounts/fireworks/models/llama-v3p2-90b-vision-instruct"

	// embeddingModel is the model used for embedding requests.
	embeddingModel = "nomic-ai/nomic-embed-text-v1.5"

	// maxImagesPerCall caps the number of images sent in one chat completion request.
	maxImagesPerCall = 10

	// maxTextsPerBatch caps the number of texts sent in one embedding request.
	maxTextsPerBatch = 50
)

// apiKey is the bearer token sent in the Authorization header. It is resolved
// once, at package initialisation.
var apiKey = resolveAPIKey()

// resolveAPIKey returns the value of the FIREWORKS_API_KEY environment
// variable, or fallbackAPIKey when that variable is unset or empty.
func resolveAPIKey() string {
	if key := os.Getenv(apiKeyEnvVar); key != "" {
		return key
	}
	return fallbackAPIKey
}
|
|
|
|
// Client sends requests to the Fireworks API.
type Client struct {
	// http is the HTTP client used to execute every request.
	http *http.Client
}
|
|
|
|
func NewClient() *Client {
|
|
return &Client{http: &http.Client{}}
|
|
}
|
|
|
|
// ExtractToMarkdown sends base64 images to the vision model and returns extracted markdown.
|
|
// For XLSX text content, pass nil images and set textContent instead.
|
|
func (c *Client) ExtractToMarkdown(ctx context.Context, imageBase64 []string, filename string) (string, error) {
|
|
if len(imageBase64) == 0 {
|
|
return "", fmt.Errorf("no images provided")
|
|
}
|
|
|
|
var fullMarkdown string
|
|
|
|
// Batch images into groups of maxImagesPerCall
|
|
for i := 0; i < len(imageBase64); i += maxImagesPerCall {
|
|
end := i + maxImagesPerCall
|
|
if end > len(imageBase64) {
|
|
end = len(imageBase64)
|
|
}
|
|
batch := imageBase64[i:end]
|
|
|
|
content := []map[string]interface{}{
|
|
{"type": "text", "text": fmt.Sprintf("Extract all content from this document (%s) into clean markdown. Preserve headings, tables, lists, and structure. Do not summarise — extract everything.", filename)},
|
|
}
|
|
for _, img := range batch {
|
|
content = append(content, map[string]interface{}{
|
|
"type": "image_url",
|
|
"image_url": map[string]string{
|
|
"url": "data:image/jpeg;base64," + img,
|
|
},
|
|
})
|
|
}
|
|
|
|
body := map[string]interface{}{
|
|
"model": visionModel,
|
|
"messages": []map[string]interface{}{
|
|
{
|
|
"role": "system",
|
|
"content": "You are a document extraction expert. Extract ALL content from this document into clean markdown. Preserve headings, tables, lists, and structure. Do not summarise — extract everything.",
|
|
},
|
|
{
|
|
"role": "user",
|
|
"content": content,
|
|
},
|
|
},
|
|
"max_tokens": 16384,
|
|
}
|
|
|
|
result, err := c.chatCompletion(ctx, body)
|
|
if err != nil {
|
|
return fullMarkdown, fmt.Errorf("vision extraction batch %d: %w", i/maxImagesPerCall, err)
|
|
}
|
|
fullMarkdown += result + "\n"
|
|
}
|
|
|
|
return fullMarkdown, nil
|
|
}
|
|
|
|
// ExtractTextToMarkdown sends structured text (e.g. XLSX dump) to the model for markdown conversion.
|
|
func (c *Client) ExtractTextToMarkdown(ctx context.Context, textContent string, filename string) (string, error) {
|
|
body := map[string]interface{}{
|
|
"model": visionModel,
|
|
"messages": []map[string]interface{}{
|
|
{
|
|
"role": "system",
|
|
"content": "You are a document extraction expert. Convert the following structured data into clean markdown. Preserve tables, lists, and structure.",
|
|
},
|
|
{
|
|
"role": "user",
|
|
"content": fmt.Sprintf("File: %s\n\n%s", filename, textContent),
|
|
},
|
|
},
|
|
"max_tokens": 16384,
|
|
}
|
|
return c.chatCompletion(ctx, body)
|
|
}
|
|
|
|
func (c *Client) chatCompletion(ctx context.Context, body map[string]interface{}) (string, error) {
|
|
jsonBody, err := json.Marshal(body)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "POST", baseURL+"/chat/completions", bytes.NewReader(jsonBody))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", "Bearer "+apiKey)
|
|
|
|
resp, err := c.http.Do(req)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
respBody, _ := io.ReadAll(resp.Body)
|
|
if resp.StatusCode != 200 {
|
|
return "", fmt.Errorf("fireworks API error %d: %s", resp.StatusCode, string(respBody))
|
|
}
|
|
|
|
var result struct {
|
|
Choices []struct {
|
|
Message struct {
|
|
Content string `json:"content"`
|
|
} `json:"message"`
|
|
} `json:"choices"`
|
|
}
|
|
if err := json.Unmarshal(respBody, &result); err != nil {
|
|
return "", fmt.Errorf("parse response: %w", err)
|
|
}
|
|
if len(result.Choices) == 0 {
|
|
return "", fmt.Errorf("no choices in response")
|
|
}
|
|
return result.Choices[0].Message.Content, nil
|
|
}
|
|
|
|
// EmbedText generates embeddings for a batch of texts.
|
|
func (c *Client) EmbedText(ctx context.Context, texts []string) ([][]float32, error) {
|
|
var allEmbeddings [][]float32
|
|
|
|
for i := 0; i < len(texts); i += maxTextsPerBatch {
|
|
end := i + maxTextsPerBatch
|
|
if end > len(texts) {
|
|
end = len(texts)
|
|
}
|
|
batch := texts[i:end]
|
|
|
|
body := map[string]interface{}{
|
|
"model": embeddingModel,
|
|
"input": batch,
|
|
}
|
|
jsonBody, err := json.Marshal(body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "POST", baseURL+"/embeddings", bytes.NewReader(jsonBody))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", "Bearer "+apiKey)
|
|
|
|
resp, err := c.http.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
respBody, _ := io.ReadAll(resp.Body)
|
|
if resp.StatusCode != 200 {
|
|
return nil, fmt.Errorf("fireworks embedding API error %d: %s", resp.StatusCode, string(respBody))
|
|
}
|
|
|
|
var result struct {
|
|
Data []struct {
|
|
Embedding []float32 `json:"embedding"`
|
|
} `json:"data"`
|
|
}
|
|
if err := json.Unmarshal(respBody, &result); err != nil {
|
|
return nil, fmt.Errorf("parse embedding response: %w", err)
|
|
}
|
|
|
|
for _, d := range result.Data {
|
|
allEmbeddings = append(allEmbeddings, d.Embedding)
|
|
}
|
|
}
|
|
|
|
return allEmbeddings, nil
|
|
}
|
|
|
|
// CosineSimilarity returns the cosine of the angle between vectors a and b.
//
// The sums are accumulated in float64 and the result is converted back to
// float32. It returns 0 when the vectors are empty, when their lengths differ,
// or when the product of their magnitudes is zero.
func CosineSimilarity(a, b []float32) float32 {
	n := len(a)
	if n == 0 || n != len(b) {
		return 0
	}

	var (
		dotProduct float64
		sumSqA     float64
		sumSqB     float64
	)
	for idx := 0; idx < n; idx++ {
		dotProduct += float64(a[idx]) * float64(b[idx])
		sumSqA += float64(a[idx]) * float64(a[idx])
		sumSqB += float64(b[idx]) * float64(b[idx])
	}

	magnitude := math.Sqrt(sumSqA) * math.Sqrt(sumSqB)
	if magnitude != 0 {
		return float32(dotProduct / magnitude)
	}
	return 0
}
|