package api import ( "context" "encoding/json" "fmt" "strings" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" "github.com/mish/dealspace/lib" ) // NewMCPServer creates an MCP server with all Dealspace tools registered. func NewMCPServer(db *lib.DB, cfg *lib.Config) *server.MCPServer { s := server.NewMCPServer( "dealspace", "1.0.0", server.WithToolCapabilities(false), server.WithRecovery(), ) s.AddTool(listProjectsTool(), listProjectsHandler(db, cfg)) s.AddTool(getProjectTool(), getProjectHandler(db, cfg)) s.AddTool(listRequestsTool(), listRequestsHandler(db, cfg)) s.AddTool(getRequestTool(), getRequestHandler(db, cfg)) s.AddTool(searchRequestsTool(), searchRequestsHandler(db, cfg)) s.AddTool(updateRequestStatusTool(), updateRequestStatusHandler(db, cfg)) s.AddTool(addCommentTool(), addCommentHandler(db, cfg)) s.AddTool(listAnswersTool(), listAnswersHandler(db, cfg)) s.AddTool(linkAnswerTool(), linkAnswerHandler(db)) s.AddTool(listTasksTool(), listTasksHandler(db, cfg)) return s } // MCPContextFunc propagates user_id from OAuth middleware into MCP tool context. func MCPContextFunc(ctx context.Context, r interface{ Context() context.Context }) context.Context { if uid := UserIDFromContext(r.Context()); uid != "" { return context.WithValue(ctx, ctxUserID, uid) } return ctx } // --- Tool definitions --- func listProjectsTool() mcp.Tool { return mcp.NewTool("list_projects", mcp.WithDescription("List all projects the user has access to"), ) } func getProjectTool() mcp.Tool { return mcp.NewTool("get_project", mcp.WithDescription("Get project details including workstreams and summary"), mcp.WithString("project_id", mcp.Description("The project ID"), mcp.Required()), ) } func listRequestsTool() mcp.Tool { return mcp.NewTool("list_requests", mcp.WithDescription("List all requests in a project as a tree (request lists, sections, requests)"), mcp.WithString("project_id", mcp.Description("The project ID"), mcp.Required()), ) } func getRequestTool() mcp.Tool { return mcp.NewTool("get_request", mcp.WithDescription("Get request details including linked answers"), mcp.WithString("request_id", mcp.Description("The request ID"), mcp.Required()), ) } func searchRequestsTool() mcp.Tool { return mcp.NewTool("search_requests", mcp.WithDescription("Search requests by keyword in title, description, or comments"), mcp.WithString("project_id", mcp.Description("The project ID"), mcp.Required()), mcp.WithString("query", mcp.Description("Search keyword"), mcp.Required()), ) } func updateRequestStatusTool() mcp.Tool { return mcp.NewTool("update_request_status", mcp.WithDescription("Update the status of a request"), mcp.WithString("request_id", mcp.Description("The request ID"), mcp.Required()), mcp.WithString("status", mcp.Description("New status: open, assigned, answered, review, or published"), mcp.Required()), ) } func addCommentTool() mcp.Tool { return mcp.NewTool("add_comment", mcp.WithDescription("Add a buyer or seller comment to a request"), mcp.WithString("request_id", mcp.Description("The request ID"), mcp.Required()), mcp.WithString("comment", mcp.Description("The comment text"), mcp.Required()), mcp.WithString("side", mcp.Description("Which side: buyer or seller"), mcp.Required()), ) } func listAnswersTool() mcp.Tool { return mcp.NewTool("list_answers", mcp.WithDescription("List answer documents in a project"), mcp.WithString("project_id", mcp.Description("The project ID"), mcp.Required()), ) } func linkAnswerTool() mcp.Tool { return mcp.NewTool("link_answer", mcp.WithDescription("Link an existing answer document to a request"), mcp.WithString("request_id", mcp.Description("The request ID"), mcp.Required()), mcp.WithString("answer_id", mcp.Description("The answer entry ID"), mcp.Required()), ) } func listTasksTool() mcp.Tool { return mcp.NewTool("list_tasks", mcp.WithDescription("List tasks assigned to the current user"), mcp.WithString("project_id", mcp.Description("Optional project ID to filter tasks")), ) } // --- Tool handlers --- func listProjectsHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc { return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { userID := UserIDFromContext(ctx) if userID == "" { return mcp.NewToolResultError("Not authenticated"), nil } projects, err := lib.ProjectsByUser(db, cfg, userID) if err != nil { return mcp.NewToolResultError("Failed to list projects: " + err.Error()), nil } var lines []string for _, p := range projects { var data map[string]any if p.DataText != "" { _ = json.Unmarshal([]byte(p.DataText), &data) } name, _ := data["name"].(string) if name == "" { name = p.EntryID } lines = append(lines, fmt.Sprintf("- %s (id: %s, stage: %s)", name, p.EntryID, p.Stage)) } if len(lines) == 0 { return mcp.NewToolResultText("No projects found."), nil } return mcp.NewToolResultText("Projects:\n" + strings.Join(lines, "\n")), nil } } func getProjectHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc { return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { userID := UserIDFromContext(ctx) if userID == "" { return mcp.NewToolResultError("Not authenticated"), nil } projectID, err := req.RequireString("project_id") if err != nil { return mcp.NewToolResultError("project_id is required"), nil } entry, err := lib.EntryByID(db, cfg, projectID) if err != nil { return mcp.NewToolResultError("Project not found"), nil } if entry.Type != lib.TypeProject { return mcp.NewToolResultError("Not a project"), nil } return mcp.NewToolResultText(fmt.Sprintf("Project: %s\nID: %s\nStage: %s\nCreated: %d\nData: %s", entry.SummaryText, entry.EntryID, entry.Stage, entry.CreatedAt, entry.DataText)), nil } } func listRequestsHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc { return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { userID := UserIDFromContext(ctx) if userID == "" { return mcp.NewToolResultError("Not authenticated"), nil } projectID, err := req.RequireString("project_id") if err != nil { return mcp.NewToolResultError("project_id is required"), nil } // Verify access if err := lib.CheckAccessRead(db, userID, projectID, ""); err != nil { return mcp.NewToolResultError("Access denied"), nil } filter := lib.EntryFilter{ProjectID: projectID} entries, err := lib.EntryRead(db, cfg, userID, projectID, filter) if err != nil { return mcp.NewToolResultError("Failed to list entries: " + err.Error()), nil } // Filter to request types and build tree text var lines []string for _, e := range entries { if e.Type != lib.TypeRequestList && e.Type != lib.TypeSection && e.Type != lib.TypeRequest { continue } indent := strings.Repeat(" ", e.Depth) var data map[string]any if e.DataText != "" { _ = json.Unmarshal([]byte(e.DataText), &data) } switch e.Type { case lib.TypeRequestList: name, _ := data["name"].(string) lines = append(lines, fmt.Sprintf("%s[Request List] %s (id: %s)", indent, name, e.EntryID)) case lib.TypeSection: name, _ := data["name"].(string) lines = append(lines, fmt.Sprintf("%s[Section] %s (id: %s)", indent, name, e.EntryID)) case lib.TypeRequest: title, _ := data["title"].(string) status, _ := data["status"].(string) priority, _ := data["priority"].(string) lines = append(lines, fmt.Sprintf("%s- %s [%s/%s] (id: %s)", indent, title, status, priority, e.EntryID)) } } if len(lines) == 0 { return mcp.NewToolResultText("No requests found in this project."), nil } return mcp.NewToolResultText("Request tree:\n" + strings.Join(lines, "\n")), nil } } func getRequestHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc { return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { userID := UserIDFromContext(ctx) if userID == "" { return mcp.NewToolResultError("Not authenticated"), nil } requestID, err := req.RequireString("request_id") if err != nil { return mcp.NewToolResultError("request_id is required"), nil } entry, err := lib.EntryByID(db, cfg, requestID) if err != nil { return mcp.NewToolResultError("Request not found"), nil } result := fmt.Sprintf("Request ID: %s\nProject: %s\nType: %s\nData: %s", entry.EntryID, entry.ProjectID, entry.Type, entry.DataText) // Get linked answers links, err := lib.AnswerLinksByRequest(db, requestID) if err == nil && len(links) > 0 { result += "\n\nLinked Answers:" for _, link := range links { answerEntry, err := lib.EntryByID(db, cfg, link.AnswerID) if err == nil { result += fmt.Sprintf("\n - %s (id: %s)", answerEntry.SummaryText, link.AnswerID) } } } return mcp.NewToolResultText(result), nil } } func searchRequestsHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc { return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { userID := UserIDFromContext(ctx) if userID == "" { return mcp.NewToolResultError("Not authenticated"), nil } projectID, err := req.RequireString("project_id") if err != nil { return mcp.NewToolResultError("project_id is required"), nil } query, err := req.RequireString("query") if err != nil { return mcp.NewToolResultError("query is required"), nil } if err := lib.CheckAccessRead(db, userID, projectID, ""); err != nil { return mcp.NewToolResultError("Access denied"), nil } filter := lib.EntryFilter{ProjectID: projectID, Type: lib.TypeRequest} entries, err := lib.EntryRead(db, cfg, userID, projectID, filter) if err != nil { return mcp.NewToolResultError("Failed to search: " + err.Error()), nil } queryLower := strings.ToLower(query) var lines []string for _, e := range entries { if strings.Contains(strings.ToLower(e.DataText), queryLower) || strings.Contains(strings.ToLower(e.SummaryText), queryLower) { var data map[string]any if e.DataText != "" { _ = json.Unmarshal([]byte(e.DataText), &data) } title, _ := data["title"].(string) status, _ := data["status"].(string) lines = append(lines, fmt.Sprintf("- %s [%s] (id: %s)", title, status, e.EntryID)) } } if len(lines) == 0 { return mcp.NewToolResultText("No requests matching '" + query + "' found."), nil } return mcp.NewToolResultText(fmt.Sprintf("Found %d results:\n%s", len(lines), strings.Join(lines, "\n"))), nil } } func updateRequestStatusHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc { return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { userID := UserIDFromContext(ctx) if userID == "" { return mcp.NewToolResultError("Not authenticated"), nil } requestID, err := req.RequireString("request_id") if err != nil { return mcp.NewToolResultError("request_id is required"), nil } status, err := req.RequireString("status") if err != nil { return mcp.NewToolResultError("status is required"), nil } validStatuses := map[string]bool{"open": true, "assigned": true, "answered": true, "review": true, "published": true} if !validStatuses[status] { return mcp.NewToolResultError("Invalid status. Must be: open, assigned, answered, review, or published"), nil } entry, err := lib.EntryByID(db, cfg, requestID) if err != nil { return mcp.NewToolResultError("Request not found"), nil } if entry.Type != lib.TypeRequest { return mcp.NewToolResultError("Entry is not a request"), nil } // Parse current data, update status var data lib.RequestData if entry.DataText != "" { _ = json.Unmarshal([]byte(entry.DataText), &data) } data.Status = status dataJSON, _ := json.Marshal(data) entry.DataText = string(dataJSON) if err := lib.EntryWrite(db, cfg, userID, entry); err != nil { return mcp.NewToolResultError("Failed to update: " + err.Error()), nil } return mcp.NewToolResultText(fmt.Sprintf("Updated request %s status to '%s'", requestID, status)), nil } } func addCommentHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc { return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { userID := UserIDFromContext(ctx) if userID == "" { return mcp.NewToolResultError("Not authenticated"), nil } requestID, err := req.RequireString("request_id") if err != nil { return mcp.NewToolResultError("request_id is required"), nil } comment, err := req.RequireString("comment") if err != nil { return mcp.NewToolResultError("comment is required"), nil } side, err := req.RequireString("side") if err != nil { return mcp.NewToolResultError("side is required"), nil } if side != "buyer" && side != "seller" { return mcp.NewToolResultError("side must be 'buyer' or 'seller'"), nil } entry, err := lib.EntryByID(db, cfg, requestID) if err != nil { return mcp.NewToolResultError("Request not found"), nil } if entry.Type != lib.TypeRequest { return mcp.NewToolResultError("Entry is not a request"), nil } var data lib.RequestData if entry.DataText != "" { _ = json.Unmarshal([]byte(entry.DataText), &data) } if side == "buyer" { data.BuyerComment = comment } else { data.SellerComment = comment } dataJSON, _ := json.Marshal(data) entry.DataText = string(dataJSON) if err := lib.EntryWrite(db, cfg, userID, entry); err != nil { return mcp.NewToolResultError("Failed to update: " + err.Error()), nil } return mcp.NewToolResultText(fmt.Sprintf("Added %s comment to request %s", side, requestID)), nil } } func listAnswersHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc { return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { userID := UserIDFromContext(ctx) if userID == "" { return mcp.NewToolResultError("Not authenticated"), nil } projectID, err := req.RequireString("project_id") if err != nil { return mcp.NewToolResultError("project_id is required"), nil } if err := lib.CheckAccessRead(db, userID, projectID, ""); err != nil { return mcp.NewToolResultError("Access denied"), nil } filter := lib.EntryFilter{ProjectID: projectID, Type: lib.TypeAnswer} entries, err := lib.EntryRead(db, cfg, userID, projectID, filter) if err != nil { return mcp.NewToolResultError("Failed to list answers: " + err.Error()), nil } var lines []string for _, e := range entries { lines = append(lines, fmt.Sprintf("- %s (id: %s)", e.SummaryText, e.EntryID)) } if len(lines) == 0 { return mcp.NewToolResultText("No answer documents found."), nil } return mcp.NewToolResultText("Answers:\n" + strings.Join(lines, "\n")), nil } } func linkAnswerHandler(db *lib.DB) server.ToolHandlerFunc { return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { userID := UserIDFromContext(ctx) if userID == "" { return mcp.NewToolResultError("Not authenticated"), nil } requestID, err := req.RequireString("request_id") if err != nil { return mcp.NewToolResultError("request_id is required"), nil } answerID, err := req.RequireString("answer_id") if err != nil { return mcp.NewToolResultError("answer_id is required"), nil } if err := lib.AnswerLinkCreate(db, answerID, requestID, userID); err != nil { return mcp.NewToolResultError("Failed to link answer: " + err.Error()), nil } return mcp.NewToolResultText(fmt.Sprintf("Linked answer %s to request %s", answerID, requestID)), nil } } func listTasksHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc { return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { userID := UserIDFromContext(ctx) if userID == "" { return mcp.NewToolResultError("Not authenticated"), nil } tasks, err := lib.TasksByUser(db, cfg, userID) if err != nil { return mcp.NewToolResultError("Failed to list tasks: " + err.Error()), nil } projectID := req.GetString("project_id", "") var lines []string for _, t := range tasks { if projectID != "" && t.ProjectID != projectID { continue } lines = append(lines, fmt.Sprintf("- %s (id: %s, project: %s, stage: %s)", t.SummaryText, t.EntryID, t.ProjectID, t.Stage)) } if len(lines) == 0 { return mcp.NewToolResultText("No tasks assigned to you."), nil } return mcp.NewToolResultText("Your tasks:\n" + strings.Join(lines, "\n")), nil } }