dealspace/api/mcp.go

509 lines
16 KiB
Go

package api
import (
"context"
"encoding/json"
"fmt"
"strings"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
"github.com/mish/dealspace/lib"
)
// NewMCPServer builds the "dealspace" MCP server (version 1.0.0) with panic
// recovery enabled and every Dealspace tool registered against db and cfg.
func NewMCPServer(db *lib.DB, cfg *lib.Config) *server.MCPServer {
	srv := server.NewMCPServer("dealspace", "1.0.0",
		server.WithToolCapabilities(false),
		server.WithRecovery(),
	)

	// Tool schema paired with the handler that serves it, in registration order.
	registrations := []struct {
		tool    mcp.Tool
		handler server.ToolHandlerFunc
	}{
		{listProjectsTool(), listProjectsHandler(db, cfg)},
		{getProjectTool(), getProjectHandler(db, cfg)},
		{listRequestsTool(), listRequestsHandler(db, cfg)},
		{getRequestTool(), getRequestHandler(db, cfg)},
		{searchRequestsTool(), searchRequestsHandler(db, cfg)},
		{updateRequestStatusTool(), updateRequestStatusHandler(db, cfg)},
		{addCommentTool(), addCommentHandler(db, cfg)},
		{listAnswersTool(), listAnswersHandler(db, cfg)},
		{linkAnswerTool(), linkAnswerHandler(db)},
		{listTasksTool(), listTasksHandler(db, cfg)},
	}
	for _, reg := range registrations {
		srv.AddTool(reg.tool, reg.handler)
	}
	return srv
}
// MCPContextFunc copies the user ID that the OAuth middleware stored on the
// incoming request's context into ctx, so MCP tool handlers can read it with
// UserIDFromContext. When the request carries no user ID, ctx is returned as-is.
// Any value with a Context() method (for example *http.Request) can be passed as r.
func MCPContextFunc(ctx context.Context, r interface{ Context() context.Context }) context.Context {
	uid := UserIDFromContext(r.Context())
	if uid == "" {
		return ctx
	}
	return context.WithValue(ctx, ctxUserID, uid)
}
// --- Tool definitions ---

// listProjectsTool describes the parameterless list_projects tool.
func listProjectsTool() mcp.Tool {
	opts := []mcp.ToolOption{
		mcp.WithDescription("List all projects the user has access to"),
	}
	return mcp.NewTool("list_projects", opts...)
}
// getProjectTool describes get_project, which takes a required project_id.
func getProjectTool() mcp.Tool {
	projectParam := mcp.WithString("project_id", mcp.Description("The project ID"), mcp.Required())
	return mcp.NewTool("get_project",
		mcp.WithDescription("Get project details including workstreams and summary"),
		projectParam,
	)
}
// listRequestsTool describes list_requests, which takes a required project_id.
func listRequestsTool() mcp.Tool {
	const description = "List all requests in a project as a tree (request lists, sections, requests)"
	projectParam := mcp.WithString("project_id", mcp.Description("The project ID"), mcp.Required())
	return mcp.NewTool("list_requests", mcp.WithDescription(description), projectParam)
}
// getRequestTool describes get_request, which takes a required request_id.
func getRequestTool() mcp.Tool {
	opts := []mcp.ToolOption{
		mcp.WithDescription("Get request details including linked answers"),
		mcp.WithString("request_id", mcp.Description("The request ID"), mcp.Required()),
	}
	return mcp.NewTool("get_request", opts...)
}
// searchRequestsTool describes search_requests, which takes a required
// project_id and a required keyword query.
func searchRequestsTool() mcp.Tool {
	projectParam := mcp.WithString("project_id", mcp.Description("The project ID"), mcp.Required())
	queryParam := mcp.WithString("query", mcp.Description("Search keyword"), mcp.Required())
	return mcp.NewTool("search_requests",
		mcp.WithDescription("Search requests by keyword in title, description, or comments"),
		projectParam,
		queryParam,
	)
}
// updateRequestStatusTool describes update_request_status. It takes a required
// request_id and a required status. The status values accepted by
// updateRequestStatusHandler are published as a JSON-schema enum, so MCP
// clients see the allowed set in the schema rather than only in prose.
func updateRequestStatusTool() mcp.Tool {
	return mcp.NewTool("update_request_status",
		mcp.WithDescription("Update the status of a request"),
		mcp.WithString("request_id", mcp.Description("The request ID"), mcp.Required()),
		mcp.WithString("status",
			mcp.Description("New status: open, assigned, answered, review, or published"),
			mcp.Required(),
			mcp.Enum("open", "assigned", "answered", "review", "published"),
		),
	)
}
// addCommentTool describes add_comment. It takes a required request_id, the
// comment text, and the side ("buyer" or "seller"). The side is published as a
// JSON-schema enum so clients see the same constraint addCommentHandler enforces.
func addCommentTool() mcp.Tool {
	return mcp.NewTool("add_comment",
		mcp.WithDescription("Add a buyer or seller comment to a request"),
		mcp.WithString("request_id", mcp.Description("The request ID"), mcp.Required()),
		mcp.WithString("comment", mcp.Description("The comment text"), mcp.Required()),
		mcp.WithString("side",
			mcp.Description("Which side: buyer or seller"),
			mcp.Required(),
			mcp.Enum("buyer", "seller"),
		),
	)
}
// listAnswersTool describes list_answers, which takes a required project_id.
func listAnswersTool() mcp.Tool {
	projectParam := mcp.WithString("project_id", mcp.Description("The project ID"), mcp.Required())
	opts := []mcp.ToolOption{mcp.WithDescription("List answer documents in a project"), projectParam}
	return mcp.NewTool("list_answers", opts...)
}
// linkAnswerTool describes link_answer, which takes a required request_id and
// a required answer_id.
func linkAnswerTool() mcp.Tool {
	requestParam := mcp.WithString("request_id", mcp.Description("The request ID"), mcp.Required())
	answerParam := mcp.WithString("answer_id", mcp.Description("The answer entry ID"), mcp.Required())
	return mcp.NewTool("link_answer",
		mcp.WithDescription("Link an existing answer document to a request"),
		requestParam,
		answerParam,
	)
}
// listTasksTool describes list_tasks. Its project_id parameter is optional
// (no mcp.Required) and narrows the result to one project.
func listTasksTool() mcp.Tool {
	opts := []mcp.ToolOption{
		mcp.WithDescription("List tasks assigned to the current user"),
		mcp.WithString("project_id", mcp.Description("Optional project ID to filter tasks")),
	}
	return mcp.NewTool("list_tasks", opts...)
}
// --- Tool handlers ---

// listProjectsHandler serves list_projects. It renders one line per project
// returned by lib.ProjectsByUser for the authenticated caller, labelled with the
// project's "name" data field or, if that is missing, its entry ID.
func listProjectsHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc {
	return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		uid := UserIDFromContext(ctx)
		if uid == "" {
			return mcp.NewToolResultError("Not authenticated"), nil
		}

		projects, err := lib.ProjectsByUser(db, cfg, uid)
		if err != nil {
			return mcp.NewToolResultError("Failed to list projects: " + err.Error()), nil
		}
		if len(projects) == 0 {
			return mcp.NewToolResultText("No projects found."), nil
		}

		var out strings.Builder
		out.WriteString("Projects:")
		for _, p := range projects {
			// Decoding is best-effort: undecodable data just means no "name" key.
			var fields map[string]any
			if p.DataText != "" {
				_ = json.Unmarshal([]byte(p.DataText), &fields)
			}
			label, _ := fields["name"].(string)
			if label == "" {
				label = p.EntryID
			}
			fmt.Fprintf(&out, "\n- %s (id: %s, stage: %s)", label, p.EntryID, p.Stage)
		}
		return mcp.NewToolResultText(out.String()), nil
	}
}
// getProjectHandler serves get_project. It returns the summary, ID, stage,
// creation time and raw data of the project identified by project_id.
//
// lib.EntryByID takes no caller identity, so the handler first enforces
// project read access with lib.CheckAccessRead, as the other project-scoped
// handlers do. Without this check any authenticated user could read any
// project by guessing its ID.
func getProjectHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc {
	return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		userID := UserIDFromContext(ctx)
		if userID == "" {
			return mcp.NewToolResultError("Not authenticated"), nil
		}
		projectID, err := req.RequireString("project_id")
		if err != nil {
			return mcp.NewToolResultError("project_id is required"), nil
		}
		// Authorize before the lookup so unauthorized callers learn nothing about
		// whether the ID exists or what type of entry it is.
		if err := lib.CheckAccessRead(db, userID, projectID, ""); err != nil {
			return mcp.NewToolResultError("Access denied"), nil
		}
		entry, err := lib.EntryByID(db, cfg, projectID)
		if err != nil {
			return mcp.NewToolResultError("Project not found"), nil
		}
		if entry.Type != lib.TypeProject {
			return mcp.NewToolResultError("Not a project"), nil
		}
		return mcp.NewToolResultText(fmt.Sprintf("Project: %s\nID: %s\nStage: %s\nCreated: %d\nData: %s",
			entry.SummaryText, entry.EntryID, entry.Stage, entry.CreatedAt, entry.DataText)), nil
	}
}
// listRequestsHandler serves list_requests. After checking project read
// access, it renders every request list, section and request in the project
// as one line each, indented by one space per level of the entry's Depth.
// Other entry types are skipped.
func listRequestsHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc {
	return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		uid := UserIDFromContext(ctx)
		if uid == "" {
			return mcp.NewToolResultError("Not authenticated"), nil
		}
		pid, err := req.RequireString("project_id")
		if err != nil {
			return mcp.NewToolResultError("project_id is required"), nil
		}
		if err := lib.CheckAccessRead(db, uid, pid, ""); err != nil {
			return mcp.NewToolResultError("Access denied"), nil
		}

		entries, err := lib.EntryRead(db, cfg, uid, pid, lib.EntryFilter{ProjectID: pid})
		if err != nil {
			return mcp.NewToolResultError("Failed to list entries: " + err.Error()), nil
		}

		var tree []string
		for _, e := range entries {
			isNode := e.Type == lib.TypeRequestList || e.Type == lib.TypeSection || e.Type == lib.TypeRequest
			if !isNode {
				continue
			}
			pad := strings.Repeat(" ", e.Depth)

			// Best-effort decode; absent or non-string keys render as "".
			var fields map[string]any
			if e.DataText != "" {
				_ = json.Unmarshal([]byte(e.DataText), &fields)
			}
			text := func(key string) string {
				s, _ := fields[key].(string)
				return s
			}

			switch e.Type {
			case lib.TypeRequestList:
				tree = append(tree, fmt.Sprintf("%s[Request List] %s (id: %s)", pad, text("name"), e.EntryID))
			case lib.TypeSection:
				tree = append(tree, fmt.Sprintf("%s[Section] %s (id: %s)", pad, text("name"), e.EntryID))
			case lib.TypeRequest:
				tree = append(tree, fmt.Sprintf("%s- %s [%s/%s] (id: %s)",
					pad, text("title"), text("status"), text("priority"), e.EntryID))
			}
		}

		if len(tree) == 0 {
			return mcp.NewToolResultText("No requests found in this project."), nil
		}
		return mcp.NewToolResultText("Request tree:\n" + strings.Join(tree, "\n")), nil
	}
}
// getRequestHandler serves get_request. It returns the entry's ID, project,
// type and raw data, followed by the linked answers that still resolve.
//
// lib.EntryByID takes no caller identity, so the handler enforces read access
// to the entry's project before returning anything. Previously any
// authenticated user could read any request by ID.
func getRequestHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc {
	return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		userID := UserIDFromContext(ctx)
		if userID == "" {
			return mcp.NewToolResultError("Not authenticated"), nil
		}
		requestID, err := req.RequireString("request_id")
		if err != nil {
			return mcp.NewToolResultError("request_id is required"), nil
		}
		entry, err := lib.EntryByID(db, cfg, requestID)
		if err != nil {
			return mcp.NewToolResultError("Request not found"), nil
		}
		if err := lib.CheckAccessRead(db, userID, entry.ProjectID, ""); err != nil {
			return mcp.NewToolResultError("Access denied"), nil
		}

		var out strings.Builder
		fmt.Fprintf(&out, "Request ID: %s\nProject: %s\nType: %s\nData: %s",
			entry.EntryID, entry.ProjectID, entry.Type, entry.DataText)

		// Linked answers are best-effort: a failed link lookup still returns the
		// request, and links whose answer entry cannot be loaded are skipped.
		links, err := lib.AnswerLinksByRequest(db, requestID)
		if err == nil && len(links) > 0 {
			out.WriteString("\n\nLinked Answers:")
			for _, link := range links {
				answerEntry, err := lib.EntryByID(db, cfg, link.AnswerID)
				if err != nil {
					continue
				}
				fmt.Fprintf(&out, "\n - %s (id: %s)", answerEntry.SummaryText, link.AnswerID)
			}
		}
		return mcp.NewToolResultText(out.String()), nil
	}
}
// searchRequestsHandler serves search_requests. After checking project read
// access, it lists the project's requests whose summary, or any string value in
// their JSON data, contains query (case-insensitive).
//
// Matching runs on decoded values, not on the raw stored JSON, for two reasons:
//   - Raw text includes keys, so a query like "status" or "title" would match
//     every request.
//   - encoding/json escapes &, < and > as \u0026 etc., so a query such as
//     "R&D" would never match raw text written by json.Marshal.
//
// Data that is not valid JSON falls back to raw-text matching.
func searchRequestsHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc {
	return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		userID := UserIDFromContext(ctx)
		if userID == "" {
			return mcp.NewToolResultError("Not authenticated"), nil
		}
		projectID, err := req.RequireString("project_id")
		if err != nil {
			return mcp.NewToolResultError("project_id is required"), nil
		}
		query, err := req.RequireString("query")
		// A blank query would match every request; treat it as missing.
		if err != nil || strings.TrimSpace(query) == "" {
			return mcp.NewToolResultError("query is required"), nil
		}
		if err := lib.CheckAccessRead(db, userID, projectID, ""); err != nil {
			return mcp.NewToolResultError("Access denied"), nil
		}
		filter := lib.EntryFilter{ProjectID: projectID, Type: lib.TypeRequest}
		entries, err := lib.EntryRead(db, cfg, userID, projectID, filter)
		if err != nil {
			return mcp.NewToolResultError("Failed to search: " + err.Error()), nil
		}

		queryLower := strings.ToLower(query)
		contains := func(s string) bool {
			return strings.Contains(strings.ToLower(s), queryLower)
		}
		// matchValue reports whether any string nested inside a decoded JSON
		// value (object values and array elements, recursively) contains the query.
		var matchValue func(v any) bool
		matchValue = func(v any) bool {
			switch t := v.(type) {
			case string:
				return contains(t)
			case map[string]any:
				for _, child := range t {
					if matchValue(child) {
						return true
					}
				}
			case []any:
				for _, child := range t {
					if matchValue(child) {
						return true
					}
				}
			}
			return false
		}

		var lines []string
		for _, e := range entries {
			var data map[string]any
			matched := contains(e.SummaryText)
			if e.DataText != "" {
				if err := json.Unmarshal([]byte(e.DataText), &data); err != nil {
					// Not a JSON object: fall back to matching the stored text.
					matched = matched || contains(e.DataText)
				} else {
					matched = matched || matchValue(data)
				}
			}
			if !matched {
				continue
			}
			title, _ := data["title"].(string)
			status, _ := data["status"].(string)
			lines = append(lines, fmt.Sprintf("- %s [%s] (id: %s)", title, status, e.EntryID))
		}
		if len(lines) == 0 {
			return mcp.NewToolResultText("No requests matching '" + query + "' found."), nil
		}
		return mcp.NewToolResultText(fmt.Sprintf("Found %d results:\n%s", len(lines), strings.Join(lines, "\n"))), nil
	}
}
// updateRequestStatusHandler serves update_request_status. It validates the new
// status, loads the request, and rewrites its data with the status replaced.
//
// Two fixes over the previous version:
//   - It now checks read access to the request's project. lib.EntryByID takes
//     no caller identity, so previously any authenticated user could reach the
//     write.
//   - Undecodable stored data is now rejected. Previously the decode error was
//     ignored, and a corrupt payload was overwritten with an object containing
//     only the new status.
func updateRequestStatusHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc {
	return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		userID := UserIDFromContext(ctx)
		if userID == "" {
			return mcp.NewToolResultError("Not authenticated"), nil
		}
		requestID, err := req.RequireString("request_id")
		if err != nil {
			return mcp.NewToolResultError("request_id is required"), nil
		}
		status, err := req.RequireString("status")
		if err != nil {
			return mcp.NewToolResultError("status is required"), nil
		}
		switch status {
		case "open", "assigned", "answered", "review", "published":
		default:
			return mcp.NewToolResultError("Invalid status. Must be: open, assigned, answered, review, or published"), nil
		}

		entry, err := lib.EntryByID(db, cfg, requestID)
		if err != nil {
			return mcp.NewToolResultError("Request not found"), nil
		}
		// Gate on project membership before revealing the entry type or writing.
		// NOTE(review): this is a read check; confirm lib.EntryWrite enforces
		// write permission itself.
		if err := lib.CheckAccessRead(db, userID, entry.ProjectID, ""); err != nil {
			return mcp.NewToolResultError("Access denied"), nil
		}
		if entry.Type != lib.TypeRequest {
			return mcp.NewToolResultError("Entry is not a request"), nil
		}

		// Round-trip through lib.RequestData. Only fields declared on that struct
		// survive the rewrite.
		// NOTE(review): confirm it covers every key stored in DataText.
		var data lib.RequestData
		if entry.DataText != "" {
			if err := json.Unmarshal([]byte(entry.DataText), &data); err != nil {
				return mcp.NewToolResultError("Failed to parse request data: " + err.Error()), nil
			}
		}
		data.Status = status
		dataJSON, err := json.Marshal(data)
		if err != nil {
			return mcp.NewToolResultError("Failed to encode request data: " + err.Error()), nil
		}
		entry.DataText = string(dataJSON)
		if err := lib.EntryWrite(db, cfg, userID, entry); err != nil {
			return mcp.NewToolResultError("Failed to update: " + err.Error()), nil
		}
		return mcp.NewToolResultText(fmt.Sprintf("Updated request %s status to '%s'", requestID, status)), nil
	}
}
// addCommentHandler serves add_comment. It stores comment as the request's
// BuyerComment or SellerComment, depending on side. The assignment replaces any
// existing comment on that side rather than appending to it.
//
// Two fixes over the previous version:
//   - It now checks read access to the request's project, since
//     lib.EntryByID takes no caller identity.
//   - It refuses to rewrite data it cannot decode. Previously the decode error
//     was ignored and the stored payload could be replaced by an object holding
//     only the comment.
func addCommentHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc {
	return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		userID := UserIDFromContext(ctx)
		if userID == "" {
			return mcp.NewToolResultError("Not authenticated"), nil
		}
		requestID, err := req.RequireString("request_id")
		if err != nil {
			return mcp.NewToolResultError("request_id is required"), nil
		}
		comment, err := req.RequireString("comment")
		if err != nil {
			return mcp.NewToolResultError("comment is required"), nil
		}
		side, err := req.RequireString("side")
		if err != nil {
			return mcp.NewToolResultError("side is required"), nil
		}
		if side != "buyer" && side != "seller" {
			return mcp.NewToolResultError("side must be 'buyer' or 'seller'"), nil
		}

		entry, err := lib.EntryByID(db, cfg, requestID)
		if err != nil {
			return mcp.NewToolResultError("Request not found"), nil
		}
		// Gate on project membership before revealing the entry type or writing.
		// NOTE(review): this is a read check; confirm lib.EntryWrite enforces
		// write permission itself.
		if err := lib.CheckAccessRead(db, userID, entry.ProjectID, ""); err != nil {
			return mcp.NewToolResultError("Access denied"), nil
		}
		if entry.Type != lib.TypeRequest {
			return mcp.NewToolResultError("Entry is not a request"), nil
		}

		// Round-trip through lib.RequestData. Only fields declared on that struct
		// survive the rewrite.
		// NOTE(review): confirm it covers every key stored in DataText.
		var data lib.RequestData
		if entry.DataText != "" {
			if err := json.Unmarshal([]byte(entry.DataText), &data); err != nil {
				return mcp.NewToolResultError("Failed to parse request data: " + err.Error()), nil
			}
		}
		if side == "buyer" {
			data.BuyerComment = comment
		} else {
			data.SellerComment = comment
		}
		dataJSON, err := json.Marshal(data)
		if err != nil {
			return mcp.NewToolResultError("Failed to encode request data: " + err.Error()), nil
		}
		entry.DataText = string(dataJSON)
		if err := lib.EntryWrite(db, cfg, userID, entry); err != nil {
			return mcp.NewToolResultError("Failed to update: " + err.Error()), nil
		}
		return mcp.NewToolResultText(fmt.Sprintf("Added %s comment to request %s", side, requestID)), nil
	}
}
// listAnswersHandler serves list_answers. After checking project read access,
// it lists every answer entry in the project as "- <summary> (id: <id>)".
func listAnswersHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc {
	return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		uid := UserIDFromContext(ctx)
		if uid == "" {
			return mcp.NewToolResultError("Not authenticated"), nil
		}
		pid, err := req.RequireString("project_id")
		if err != nil {
			return mcp.NewToolResultError("project_id is required"), nil
		}
		if err := lib.CheckAccessRead(db, uid, pid, ""); err != nil {
			return mcp.NewToolResultError("Access denied"), nil
		}

		answers, err := lib.EntryRead(db, cfg, uid, pid, lib.EntryFilter{ProjectID: pid, Type: lib.TypeAnswer})
		if err != nil {
			return mcp.NewToolResultError("Failed to list answers: " + err.Error()), nil
		}
		if len(answers) == 0 {
			return mcp.NewToolResultText("No answer documents found."), nil
		}

		var out strings.Builder
		out.WriteString("Answers:")
		for _, a := range answers {
			fmt.Fprintf(&out, "\n- %s (id: %s)", a.SummaryText, a.EntryID)
		}
		return mcp.NewToolResultText(out.String()), nil
	}
}
// linkAnswerHandler serves link_answer. It records a link from answer_id to
// request_id via lib.AnswerLinkCreate, attributed to the calling user.
//
// NOTE(review): no project access check is made here, and this handler has no
// *lib.Config, so it cannot resolve either entry to check membership. Confirm
// that lib.AnswerLinkCreate authorizes userID and verifies that both IDs belong
// to the same project.
func linkAnswerHandler(db *lib.DB) server.ToolHandlerFunc {
	return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		uid := UserIDFromContext(ctx)
		if uid == "" {
			return mcp.NewToolResultError("Not authenticated"), nil
		}
		reqID, err := req.RequireString("request_id")
		if err != nil {
			return mcp.NewToolResultError("request_id is required"), nil
		}
		ansID, err := req.RequireString("answer_id")
		if err != nil {
			return mcp.NewToolResultError("answer_id is required"), nil
		}

		if linkErr := lib.AnswerLinkCreate(db, ansID, reqID, uid); linkErr != nil {
			return mcp.NewToolResultError("Failed to link answer: " + linkErr.Error()), nil
		}
		msg := fmt.Sprintf("Linked answer %s to request %s", ansID, reqID)
		return mcp.NewToolResultText(msg), nil
	}
}
// listTasksHandler serves list_tasks. It lists the tasks lib.TasksByUser
// returns for the caller. When the optional project_id is non-empty, only
// tasks in that project are listed.
func listTasksHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc {
	return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		uid := UserIDFromContext(ctx)
		if uid == "" {
			return mcp.NewToolResultError("Not authenticated"), nil
		}
		tasks, err := lib.TasksByUser(db, cfg, uid)
		if err != nil {
			return mcp.NewToolResultError("Failed to list tasks: " + err.Error()), nil
		}

		wantProject := req.GetString("project_id", "")
		var out strings.Builder
		found := 0
		for _, t := range tasks {
			if wantProject != "" && t.ProjectID != wantProject {
				continue
			}
			found++
			fmt.Fprintf(&out, "\n- %s (id: %s, project: %s, stage: %s)",
				t.SummaryText, t.EntryID, t.ProjectID, t.Stage)
		}

		if found == 0 {
			return mcp.NewToolResultText("No tasks assigned to you."), nil
		}
		return mcp.NewToolResultText("Your tasks:" + out.String()), nil
	}
}