509 lines
16 KiB
Go
509 lines
16 KiB
Go
package api
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/mark3labs/mcp-go/mcp"
|
|
"github.com/mark3labs/mcp-go/server"
|
|
"github.com/mish/dealspace/lib"
|
|
)
|
|
|
|
// NewMCPServer creates an MCP server with all Dealspace tools registered.
|
|
func NewMCPServer(db *lib.DB, cfg *lib.Config) *server.MCPServer {
|
|
s := server.NewMCPServer(
|
|
"dealspace",
|
|
"1.0.0",
|
|
server.WithToolCapabilities(false),
|
|
server.WithRecovery(),
|
|
)
|
|
|
|
s.AddTool(listProjectsTool(), listProjectsHandler(db, cfg))
|
|
s.AddTool(getProjectTool(), getProjectHandler(db, cfg))
|
|
s.AddTool(listRequestsTool(), listRequestsHandler(db, cfg))
|
|
s.AddTool(getRequestTool(), getRequestHandler(db, cfg))
|
|
s.AddTool(searchRequestsTool(), searchRequestsHandler(db, cfg))
|
|
s.AddTool(updateRequestStatusTool(), updateRequestStatusHandler(db, cfg))
|
|
s.AddTool(addCommentTool(), addCommentHandler(db, cfg))
|
|
s.AddTool(listAnswersTool(), listAnswersHandler(db, cfg))
|
|
s.AddTool(linkAnswerTool(), linkAnswerHandler(db))
|
|
s.AddTool(listTasksTool(), listTasksHandler(db, cfg))
|
|
|
|
return s
|
|
}
|
|
|
|
// MCPContextFunc propagates user_id from OAuth middleware into MCP tool context.
|
|
func MCPContextFunc(ctx context.Context, r interface{ Context() context.Context }) context.Context {
|
|
if uid := UserIDFromContext(r.Context()); uid != "" {
|
|
return context.WithValue(ctx, ctxUserID, uid)
|
|
}
|
|
return ctx
|
|
}
|
|
|
|
// --- Tool definitions ---
|
|
|
|
func listProjectsTool() mcp.Tool {
|
|
return mcp.NewTool("list_projects",
|
|
mcp.WithDescription("List all projects the user has access to"),
|
|
)
|
|
}
|
|
|
|
func getProjectTool() mcp.Tool {
|
|
return mcp.NewTool("get_project",
|
|
mcp.WithDescription("Get project details including workstreams and summary"),
|
|
mcp.WithString("project_id", mcp.Description("The project ID"), mcp.Required()),
|
|
)
|
|
}
|
|
|
|
func listRequestsTool() mcp.Tool {
|
|
return mcp.NewTool("list_requests",
|
|
mcp.WithDescription("List all requests in a project as a tree (request lists, sections, requests)"),
|
|
mcp.WithString("project_id", mcp.Description("The project ID"), mcp.Required()),
|
|
)
|
|
}
|
|
|
|
func getRequestTool() mcp.Tool {
|
|
return mcp.NewTool("get_request",
|
|
mcp.WithDescription("Get request details including linked answers"),
|
|
mcp.WithString("request_id", mcp.Description("The request ID"), mcp.Required()),
|
|
)
|
|
}
|
|
|
|
func searchRequestsTool() mcp.Tool {
|
|
return mcp.NewTool("search_requests",
|
|
mcp.WithDescription("Search requests by keyword in title, description, or comments"),
|
|
mcp.WithString("project_id", mcp.Description("The project ID"), mcp.Required()),
|
|
mcp.WithString("query", mcp.Description("Search keyword"), mcp.Required()),
|
|
)
|
|
}
|
|
|
|
func updateRequestStatusTool() mcp.Tool {
|
|
return mcp.NewTool("update_request_status",
|
|
mcp.WithDescription("Update the status of a request"),
|
|
mcp.WithString("request_id", mcp.Description("The request ID"), mcp.Required()),
|
|
mcp.WithString("status", mcp.Description("New status: open, assigned, answered, review, or published"), mcp.Required()),
|
|
)
|
|
}
|
|
|
|
func addCommentTool() mcp.Tool {
|
|
return mcp.NewTool("add_comment",
|
|
mcp.WithDescription("Add a buyer or seller comment to a request"),
|
|
mcp.WithString("request_id", mcp.Description("The request ID"), mcp.Required()),
|
|
mcp.WithString("comment", mcp.Description("The comment text"), mcp.Required()),
|
|
mcp.WithString("side", mcp.Description("Which side: buyer or seller"), mcp.Required()),
|
|
)
|
|
}
|
|
|
|
func listAnswersTool() mcp.Tool {
|
|
return mcp.NewTool("list_answers",
|
|
mcp.WithDescription("List answer documents in a project"),
|
|
mcp.WithString("project_id", mcp.Description("The project ID"), mcp.Required()),
|
|
)
|
|
}
|
|
|
|
func linkAnswerTool() mcp.Tool {
|
|
return mcp.NewTool("link_answer",
|
|
mcp.WithDescription("Link an existing answer document to a request"),
|
|
mcp.WithString("request_id", mcp.Description("The request ID"), mcp.Required()),
|
|
mcp.WithString("answer_id", mcp.Description("The answer entry ID"), mcp.Required()),
|
|
)
|
|
}
|
|
|
|
func listTasksTool() mcp.Tool {
|
|
return mcp.NewTool("list_tasks",
|
|
mcp.WithDescription("List tasks assigned to the current user"),
|
|
mcp.WithString("project_id", mcp.Description("Optional project ID to filter tasks")),
|
|
)
|
|
}
|
|
|
|
// --- Tool handlers ---
|
|
|
|
func listProjectsHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc {
|
|
return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
|
userID := UserIDFromContext(ctx)
|
|
if userID == "" {
|
|
return mcp.NewToolResultError("Not authenticated"), nil
|
|
}
|
|
|
|
projects, err := lib.ProjectsByUser(db, cfg, userID)
|
|
if err != nil {
|
|
return mcp.NewToolResultError("Failed to list projects: " + err.Error()), nil
|
|
}
|
|
|
|
var lines []string
|
|
for _, p := range projects {
|
|
var data map[string]any
|
|
if p.DataText != "" {
|
|
_ = json.Unmarshal([]byte(p.DataText), &data)
|
|
}
|
|
name, _ := data["name"].(string)
|
|
if name == "" {
|
|
name = p.EntryID
|
|
}
|
|
lines = append(lines, fmt.Sprintf("- %s (id: %s, stage: %s)", name, p.EntryID, p.Stage))
|
|
}
|
|
|
|
if len(lines) == 0 {
|
|
return mcp.NewToolResultText("No projects found."), nil
|
|
}
|
|
return mcp.NewToolResultText("Projects:\n" + strings.Join(lines, "\n")), nil
|
|
}
|
|
}
|
|
|
|
func getProjectHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc {
|
|
return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
|
userID := UserIDFromContext(ctx)
|
|
if userID == "" {
|
|
return mcp.NewToolResultError("Not authenticated"), nil
|
|
}
|
|
|
|
projectID, err := req.RequireString("project_id")
|
|
if err != nil {
|
|
return mcp.NewToolResultError("project_id is required"), nil
|
|
}
|
|
|
|
entry, err := lib.EntryByID(db, cfg, projectID)
|
|
if err != nil {
|
|
return mcp.NewToolResultError("Project not found"), nil
|
|
}
|
|
if entry.Type != lib.TypeProject {
|
|
return mcp.NewToolResultError("Not a project"), nil
|
|
}
|
|
|
|
return mcp.NewToolResultText(fmt.Sprintf("Project: %s\nID: %s\nStage: %s\nCreated: %d\nData: %s",
|
|
entry.SummaryText, entry.EntryID, entry.Stage, entry.CreatedAt, entry.DataText)), nil
|
|
}
|
|
}
|
|
|
|
func listRequestsHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc {
|
|
return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
|
userID := UserIDFromContext(ctx)
|
|
if userID == "" {
|
|
return mcp.NewToolResultError("Not authenticated"), nil
|
|
}
|
|
|
|
projectID, err := req.RequireString("project_id")
|
|
if err != nil {
|
|
return mcp.NewToolResultError("project_id is required"), nil
|
|
}
|
|
|
|
// Verify access
|
|
if err := lib.CheckAccessRead(db, userID, projectID, ""); err != nil {
|
|
return mcp.NewToolResultError("Access denied"), nil
|
|
}
|
|
|
|
filter := lib.EntryFilter{ProjectID: projectID}
|
|
entries, err := lib.EntryRead(db, cfg, userID, projectID, filter)
|
|
if err != nil {
|
|
return mcp.NewToolResultError("Failed to list entries: " + err.Error()), nil
|
|
}
|
|
|
|
// Filter to request types and build tree text
|
|
var lines []string
|
|
for _, e := range entries {
|
|
if e.Type != lib.TypeRequestList && e.Type != lib.TypeSection && e.Type != lib.TypeRequest {
|
|
continue
|
|
}
|
|
indent := strings.Repeat(" ", e.Depth)
|
|
var data map[string]any
|
|
if e.DataText != "" {
|
|
_ = json.Unmarshal([]byte(e.DataText), &data)
|
|
}
|
|
|
|
switch e.Type {
|
|
case lib.TypeRequestList:
|
|
name, _ := data["name"].(string)
|
|
lines = append(lines, fmt.Sprintf("%s[Request List] %s (id: %s)", indent, name, e.EntryID))
|
|
case lib.TypeSection:
|
|
name, _ := data["name"].(string)
|
|
lines = append(lines, fmt.Sprintf("%s[Section] %s (id: %s)", indent, name, e.EntryID))
|
|
case lib.TypeRequest:
|
|
title, _ := data["title"].(string)
|
|
status, _ := data["status"].(string)
|
|
priority, _ := data["priority"].(string)
|
|
lines = append(lines, fmt.Sprintf("%s- %s [%s/%s] (id: %s)", indent, title, status, priority, e.EntryID))
|
|
}
|
|
}
|
|
|
|
if len(lines) == 0 {
|
|
return mcp.NewToolResultText("No requests found in this project."), nil
|
|
}
|
|
return mcp.NewToolResultText("Request tree:\n" + strings.Join(lines, "\n")), nil
|
|
}
|
|
}
|
|
|
|
func getRequestHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc {
|
|
return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
|
userID := UserIDFromContext(ctx)
|
|
if userID == "" {
|
|
return mcp.NewToolResultError("Not authenticated"), nil
|
|
}
|
|
|
|
requestID, err := req.RequireString("request_id")
|
|
if err != nil {
|
|
return mcp.NewToolResultError("request_id is required"), nil
|
|
}
|
|
|
|
entry, err := lib.EntryByID(db, cfg, requestID)
|
|
if err != nil {
|
|
return mcp.NewToolResultError("Request not found"), nil
|
|
}
|
|
|
|
result := fmt.Sprintf("Request ID: %s\nProject: %s\nType: %s\nData: %s",
|
|
entry.EntryID, entry.ProjectID, entry.Type, entry.DataText)
|
|
|
|
// Get linked answers
|
|
links, err := lib.AnswerLinksByRequest(db, requestID)
|
|
if err == nil && len(links) > 0 {
|
|
result += "\n\nLinked Answers:"
|
|
for _, link := range links {
|
|
answerEntry, err := lib.EntryByID(db, cfg, link.AnswerID)
|
|
if err == nil {
|
|
result += fmt.Sprintf("\n - %s (id: %s)", answerEntry.SummaryText, link.AnswerID)
|
|
}
|
|
}
|
|
}
|
|
|
|
return mcp.NewToolResultText(result), nil
|
|
}
|
|
}
|
|
|
|
func searchRequestsHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc {
|
|
return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
|
userID := UserIDFromContext(ctx)
|
|
if userID == "" {
|
|
return mcp.NewToolResultError("Not authenticated"), nil
|
|
}
|
|
|
|
projectID, err := req.RequireString("project_id")
|
|
if err != nil {
|
|
return mcp.NewToolResultError("project_id is required"), nil
|
|
}
|
|
query, err := req.RequireString("query")
|
|
if err != nil {
|
|
return mcp.NewToolResultError("query is required"), nil
|
|
}
|
|
|
|
if err := lib.CheckAccessRead(db, userID, projectID, ""); err != nil {
|
|
return mcp.NewToolResultError("Access denied"), nil
|
|
}
|
|
|
|
filter := lib.EntryFilter{ProjectID: projectID, Type: lib.TypeRequest}
|
|
entries, err := lib.EntryRead(db, cfg, userID, projectID, filter)
|
|
if err != nil {
|
|
return mcp.NewToolResultError("Failed to search: " + err.Error()), nil
|
|
}
|
|
|
|
queryLower := strings.ToLower(query)
|
|
var lines []string
|
|
for _, e := range entries {
|
|
if strings.Contains(strings.ToLower(e.DataText), queryLower) ||
|
|
strings.Contains(strings.ToLower(e.SummaryText), queryLower) {
|
|
var data map[string]any
|
|
if e.DataText != "" {
|
|
_ = json.Unmarshal([]byte(e.DataText), &data)
|
|
}
|
|
title, _ := data["title"].(string)
|
|
status, _ := data["status"].(string)
|
|
lines = append(lines, fmt.Sprintf("- %s [%s] (id: %s)", title, status, e.EntryID))
|
|
}
|
|
}
|
|
|
|
if len(lines) == 0 {
|
|
return mcp.NewToolResultText("No requests matching '" + query + "' found."), nil
|
|
}
|
|
return mcp.NewToolResultText(fmt.Sprintf("Found %d results:\n%s", len(lines), strings.Join(lines, "\n"))), nil
|
|
}
|
|
}
|
|
|
|
func updateRequestStatusHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc {
|
|
return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
|
userID := UserIDFromContext(ctx)
|
|
if userID == "" {
|
|
return mcp.NewToolResultError("Not authenticated"), nil
|
|
}
|
|
|
|
requestID, err := req.RequireString("request_id")
|
|
if err != nil {
|
|
return mcp.NewToolResultError("request_id is required"), nil
|
|
}
|
|
status, err := req.RequireString("status")
|
|
if err != nil {
|
|
return mcp.NewToolResultError("status is required"), nil
|
|
}
|
|
|
|
validStatuses := map[string]bool{"open": true, "assigned": true, "answered": true, "review": true, "published": true}
|
|
if !validStatuses[status] {
|
|
return mcp.NewToolResultError("Invalid status. Must be: open, assigned, answered, review, or published"), nil
|
|
}
|
|
|
|
entry, err := lib.EntryByID(db, cfg, requestID)
|
|
if err != nil {
|
|
return mcp.NewToolResultError("Request not found"), nil
|
|
}
|
|
if entry.Type != lib.TypeRequest {
|
|
return mcp.NewToolResultError("Entry is not a request"), nil
|
|
}
|
|
|
|
// Parse current data, update status
|
|
var data lib.RequestData
|
|
if entry.DataText != "" {
|
|
_ = json.Unmarshal([]byte(entry.DataText), &data)
|
|
}
|
|
data.Status = status
|
|
|
|
dataJSON, _ := json.Marshal(data)
|
|
entry.DataText = string(dataJSON)
|
|
|
|
if err := lib.EntryWrite(db, cfg, userID, entry); err != nil {
|
|
return mcp.NewToolResultError("Failed to update: " + err.Error()), nil
|
|
}
|
|
|
|
return mcp.NewToolResultText(fmt.Sprintf("Updated request %s status to '%s'", requestID, status)), nil
|
|
}
|
|
}
|
|
|
|
func addCommentHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc {
|
|
return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
|
userID := UserIDFromContext(ctx)
|
|
if userID == "" {
|
|
return mcp.NewToolResultError("Not authenticated"), nil
|
|
}
|
|
|
|
requestID, err := req.RequireString("request_id")
|
|
if err != nil {
|
|
return mcp.NewToolResultError("request_id is required"), nil
|
|
}
|
|
comment, err := req.RequireString("comment")
|
|
if err != nil {
|
|
return mcp.NewToolResultError("comment is required"), nil
|
|
}
|
|
side, err := req.RequireString("side")
|
|
if err != nil {
|
|
return mcp.NewToolResultError("side is required"), nil
|
|
}
|
|
|
|
if side != "buyer" && side != "seller" {
|
|
return mcp.NewToolResultError("side must be 'buyer' or 'seller'"), nil
|
|
}
|
|
|
|
entry, err := lib.EntryByID(db, cfg, requestID)
|
|
if err != nil {
|
|
return mcp.NewToolResultError("Request not found"), nil
|
|
}
|
|
if entry.Type != lib.TypeRequest {
|
|
return mcp.NewToolResultError("Entry is not a request"), nil
|
|
}
|
|
|
|
var data lib.RequestData
|
|
if entry.DataText != "" {
|
|
_ = json.Unmarshal([]byte(entry.DataText), &data)
|
|
}
|
|
|
|
if side == "buyer" {
|
|
data.BuyerComment = comment
|
|
} else {
|
|
data.SellerComment = comment
|
|
}
|
|
|
|
dataJSON, _ := json.Marshal(data)
|
|
entry.DataText = string(dataJSON)
|
|
|
|
if err := lib.EntryWrite(db, cfg, userID, entry); err != nil {
|
|
return mcp.NewToolResultError("Failed to update: " + err.Error()), nil
|
|
}
|
|
|
|
return mcp.NewToolResultText(fmt.Sprintf("Added %s comment to request %s", side, requestID)), nil
|
|
}
|
|
}
|
|
|
|
func listAnswersHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc {
|
|
return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
|
userID := UserIDFromContext(ctx)
|
|
if userID == "" {
|
|
return mcp.NewToolResultError("Not authenticated"), nil
|
|
}
|
|
|
|
projectID, err := req.RequireString("project_id")
|
|
if err != nil {
|
|
return mcp.NewToolResultError("project_id is required"), nil
|
|
}
|
|
|
|
if err := lib.CheckAccessRead(db, userID, projectID, ""); err != nil {
|
|
return mcp.NewToolResultError("Access denied"), nil
|
|
}
|
|
|
|
filter := lib.EntryFilter{ProjectID: projectID, Type: lib.TypeAnswer}
|
|
entries, err := lib.EntryRead(db, cfg, userID, projectID, filter)
|
|
if err != nil {
|
|
return mcp.NewToolResultError("Failed to list answers: " + err.Error()), nil
|
|
}
|
|
|
|
var lines []string
|
|
for _, e := range entries {
|
|
lines = append(lines, fmt.Sprintf("- %s (id: %s)", e.SummaryText, e.EntryID))
|
|
}
|
|
|
|
if len(lines) == 0 {
|
|
return mcp.NewToolResultText("No answer documents found."), nil
|
|
}
|
|
return mcp.NewToolResultText("Answers:\n" + strings.Join(lines, "\n")), nil
|
|
}
|
|
}
|
|
|
|
func linkAnswerHandler(db *lib.DB) server.ToolHandlerFunc {
|
|
return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
|
userID := UserIDFromContext(ctx)
|
|
if userID == "" {
|
|
return mcp.NewToolResultError("Not authenticated"), nil
|
|
}
|
|
|
|
requestID, err := req.RequireString("request_id")
|
|
if err != nil {
|
|
return mcp.NewToolResultError("request_id is required"), nil
|
|
}
|
|
answerID, err := req.RequireString("answer_id")
|
|
if err != nil {
|
|
return mcp.NewToolResultError("answer_id is required"), nil
|
|
}
|
|
|
|
if err := lib.AnswerLinkCreate(db, answerID, requestID, userID); err != nil {
|
|
return mcp.NewToolResultError("Failed to link answer: " + err.Error()), nil
|
|
}
|
|
|
|
return mcp.NewToolResultText(fmt.Sprintf("Linked answer %s to request %s", answerID, requestID)), nil
|
|
}
|
|
}
|
|
|
|
func listTasksHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc {
|
|
return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
|
userID := UserIDFromContext(ctx)
|
|
if userID == "" {
|
|
return mcp.NewToolResultError("Not authenticated"), nil
|
|
}
|
|
|
|
tasks, err := lib.TasksByUser(db, cfg, userID)
|
|
if err != nil {
|
|
return mcp.NewToolResultError("Failed to list tasks: " + err.Error()), nil
|
|
}
|
|
|
|
projectID := req.GetString("project_id", "")
|
|
|
|
var lines []string
|
|
for _, t := range tasks {
|
|
if projectID != "" && t.ProjectID != projectID {
|
|
continue
|
|
}
|
|
lines = append(lines, fmt.Sprintf("- %s (id: %s, project: %s, stage: %s)",
|
|
t.SummaryText, t.EntryID, t.ProjectID, t.Stage))
|
|
}
|
|
|
|
if len(lines) == 0 {
|
|
return mcp.NewToolResultText("No tasks assigned to you."), nil
|
|
}
|
|
return mcp.NewToolResultText("Your tasks:\n" + strings.Join(lines, "\n")), nil
|
|
}
|
|
}
|