diff --git a/api/mcp.go b/api/mcp.go new file mode 100644 index 0000000..6fd9135 --- /dev/null +++ b/api/mcp.go @@ -0,0 +1,508 @@ +package api + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/mish/dealspace/lib" +) + +// NewMCPServer creates an MCP server with all Dealspace tools registered. +func NewMCPServer(db *lib.DB, cfg *lib.Config) *server.MCPServer { + s := server.NewMCPServer( + "dealspace", + "1.0.0", + server.WithToolCapabilities(false), + server.WithRecovery(), + ) + + s.AddTool(listProjectsTool(), listProjectsHandler(db, cfg)) + s.AddTool(getProjectTool(), getProjectHandler(db, cfg)) + s.AddTool(listRequestsTool(), listRequestsHandler(db, cfg)) + s.AddTool(getRequestTool(), getRequestHandler(db, cfg)) + s.AddTool(searchRequestsTool(), searchRequestsHandler(db, cfg)) + s.AddTool(updateRequestStatusTool(), updateRequestStatusHandler(db, cfg)) + s.AddTool(addCommentTool(), addCommentHandler(db, cfg)) + s.AddTool(listAnswersTool(), listAnswersHandler(db, cfg)) + s.AddTool(linkAnswerTool(), linkAnswerHandler(db)) + s.AddTool(listTasksTool(), listTasksHandler(db, cfg)) + + return s +} + +// MCPContextFunc propagates user_id from OAuth middleware into MCP tool context. +func MCPContextFunc(ctx context.Context, r interface{ Context() context.Context }) context.Context { + if uid := UserIDFromContext(r.Context()); uid != "" { + return context.WithValue(ctx, ctxUserID, uid) + } + return ctx +} + +// --- Tool definitions --- + +func listProjectsTool() mcp.Tool { + return mcp.NewTool("list_projects", + mcp.WithDescription("List all projects the user has access to"), + ) +} + +func getProjectTool() mcp.Tool { + return mcp.NewTool("get_project", + mcp.WithDescription("Get project details including workstreams and summary"), + mcp.WithString("project_id", mcp.Description("The project ID"), mcp.Required()), + ) +} + +func listRequestsTool() mcp.Tool { + return mcp.NewTool("list_requests", + mcp.WithDescription("List all requests in a project as a tree (request lists, sections, requests)"), + mcp.WithString("project_id", mcp.Description("The project ID"), mcp.Required()), + ) +} + +func getRequestTool() mcp.Tool { + return mcp.NewTool("get_request", + mcp.WithDescription("Get request details including linked answers"), + mcp.WithString("request_id", mcp.Description("The request ID"), mcp.Required()), + ) +} + +func searchRequestsTool() mcp.Tool { + return mcp.NewTool("search_requests", + mcp.WithDescription("Search requests by keyword in title, description, or comments"), + mcp.WithString("project_id", mcp.Description("The project ID"), mcp.Required()), + mcp.WithString("query", mcp.Description("Search keyword"), mcp.Required()), + ) +} + +func updateRequestStatusTool() mcp.Tool { + return mcp.NewTool("update_request_status", + mcp.WithDescription("Update the status of a request"), + mcp.WithString("request_id", mcp.Description("The request ID"), mcp.Required()), + mcp.WithString("status", mcp.Description("New status: open, in_process, partial, or complete"), mcp.Required()), + ) +} + +func addCommentTool() mcp.Tool { + return mcp.NewTool("add_comment", + mcp.WithDescription("Add a buyer or seller comment to a request"), + mcp.WithString("request_id", mcp.Description("The request ID"), mcp.Required()), + mcp.WithString("comment", mcp.Description("The comment text"), mcp.Required()), + mcp.WithString("side", mcp.Description("Which side: buyer or seller"), mcp.Required()), + ) +} + +func listAnswersTool() mcp.Tool { + return mcp.NewTool("list_answers", + mcp.WithDescription("List answer documents in a project"), + mcp.WithString("project_id", mcp.Description("The project ID"), mcp.Required()), + ) +} + +func linkAnswerTool() mcp.Tool { + return mcp.NewTool("link_answer", + mcp.WithDescription("Link an existing answer document to a request"), + mcp.WithString("request_id", mcp.Description("The request ID"), mcp.Required()), + mcp.WithString("answer_id", mcp.Description("The answer entry ID"), mcp.Required()), + ) +} + +func listTasksTool() mcp.Tool { + return mcp.NewTool("list_tasks", + mcp.WithDescription("List tasks assigned to the current user"), + mcp.WithString("project_id", mcp.Description("Optional project ID to filter tasks")), + ) +} + +// --- Tool handlers --- + +func listProjectsHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + userID := UserIDFromContext(ctx) + if userID == "" { + return mcp.NewToolResultError("Not authenticated"), nil + } + + projects, err := lib.ProjectsByUser(db, cfg, userID) + if err != nil { + return mcp.NewToolResultError("Failed to list projects: " + err.Error()), nil + } + + var lines []string + for _, p := range projects { + var data map[string]any + if p.DataText != "" { + _ = json.Unmarshal([]byte(p.DataText), &data) + } + name, _ := data["name"].(string) + if name == "" { + name = p.EntryID + } + lines = append(lines, fmt.Sprintf("- %s (id: %s, stage: %s)", name, p.EntryID, p.Stage)) + } + + if len(lines) == 0 { + return mcp.NewToolResultText("No projects found."), nil + } + return mcp.NewToolResultText("Projects:\n" + strings.Join(lines, "\n")), nil + } +} + +func getProjectHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + userID := UserIDFromContext(ctx) + if userID == "" { + return mcp.NewToolResultError("Not authenticated"), nil + } + + projectID, err := req.RequireString("project_id") + if err != nil { + return mcp.NewToolResultError("project_id is required"), nil + } + + entry, err := lib.EntryByID(db, cfg, projectID) + if err != nil { + return mcp.NewToolResultError("Project not found"), nil + } + if entry.Type != lib.TypeProject { + return mcp.NewToolResultError("Not a project"), nil + } + + return mcp.NewToolResultText(fmt.Sprintf("Project: %s\nID: %s\nStage: %s\nCreated: %d\nData: %s", + entry.SummaryText, entry.EntryID, entry.Stage, entry.CreatedAt, entry.DataText)), nil + } +} + +func listRequestsHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + userID := UserIDFromContext(ctx) + if userID == "" { + return mcp.NewToolResultError("Not authenticated"), nil + } + + projectID, err := req.RequireString("project_id") + if err != nil { + return mcp.NewToolResultError("project_id is required"), nil + } + + // Verify access + if err := lib.CheckAccessRead(db, userID, projectID, ""); err != nil { + return mcp.NewToolResultError("Access denied"), nil + } + + filter := lib.EntryFilter{ProjectID: projectID} + entries, err := lib.EntryRead(db, cfg, userID, projectID, filter) + if err != nil { + return mcp.NewToolResultError("Failed to list entries: " + err.Error()), nil + } + + // Filter to request types and build tree text + var lines []string + for _, e := range entries { + if e.Type != lib.TypeRequestList && e.Type != lib.TypeSection && e.Type != lib.TypeRequest { + continue + } + indent := strings.Repeat(" ", e.Depth) + var data map[string]any + if e.DataText != "" { + _ = json.Unmarshal([]byte(e.DataText), &data) + } + + switch e.Type { + case lib.TypeRequestList: + name, _ := data["name"].(string) + lines = append(lines, fmt.Sprintf("%s[Request List] %s (id: %s)", indent, name, e.EntryID)) + case lib.TypeSection: + name, _ := data["name"].(string) + lines = append(lines, fmt.Sprintf("%s[Section] %s (id: %s)", indent, name, e.EntryID)) + case lib.TypeRequest: + title, _ := data["title"].(string) + status, _ := data["status"].(string) + priority, _ := data["priority"].(string) + lines = append(lines, fmt.Sprintf("%s- %s [%s/%s] (id: %s)", indent, title, status, priority, e.EntryID)) + } + } + + if len(lines) == 0 { + return mcp.NewToolResultText("No requests found in this project."), nil + } + return mcp.NewToolResultText("Request tree:\n" + strings.Join(lines, "\n")), nil + } +} + +func getRequestHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + userID := UserIDFromContext(ctx) + if userID == "" { + return mcp.NewToolResultError("Not authenticated"), nil + } + + requestID, err := req.RequireString("request_id") + if err != nil { + return mcp.NewToolResultError("request_id is required"), nil + } + + entry, err := lib.EntryByID(db, cfg, requestID) + if err != nil { + return mcp.NewToolResultError("Request not found"), nil + } + + result := fmt.Sprintf("Request ID: %s\nProject: %s\nType: %s\nData: %s", + entry.EntryID, entry.ProjectID, entry.Type, entry.DataText) + + // Get linked answers + links, err := lib.AnswerLinksByRequest(db, requestID) + if err == nil && len(links) > 0 { + result += "\n\nLinked Answers:" + for _, link := range links { + answerEntry, err := lib.EntryByID(db, cfg, link.AnswerID) + if err == nil { + result += fmt.Sprintf("\n - %s (id: %s)", answerEntry.SummaryText, link.AnswerID) + } + } + } + + return mcp.NewToolResultText(result), nil + } +} + +func searchRequestsHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + userID := UserIDFromContext(ctx) + if userID == "" { + return mcp.NewToolResultError("Not authenticated"), nil + } + + projectID, err := req.RequireString("project_id") + if err != nil { + return mcp.NewToolResultError("project_id is required"), nil + } + query, err := req.RequireString("query") + if err != nil { + return mcp.NewToolResultError("query is required"), nil + } + + if err := lib.CheckAccessRead(db, userID, projectID, ""); err != nil { + return mcp.NewToolResultError("Access denied"), nil + } + + filter := lib.EntryFilter{ProjectID: projectID, Type: lib.TypeRequest} + entries, err := lib.EntryRead(db, cfg, userID, projectID, filter) + if err != nil { + return mcp.NewToolResultError("Failed to search: " + err.Error()), nil + } + + queryLower := strings.ToLower(query) + var lines []string + for _, e := range entries { + if strings.Contains(strings.ToLower(e.DataText), queryLower) || + strings.Contains(strings.ToLower(e.SummaryText), queryLower) { + var data map[string]any + if e.DataText != "" { + _ = json.Unmarshal([]byte(e.DataText), &data) + } + title, _ := data["title"].(string) + status, _ := data["status"].(string) + lines = append(lines, fmt.Sprintf("- %s [%s] (id: %s)", title, status, e.EntryID)) + } + } + + if len(lines) == 0 { + return mcp.NewToolResultText("No requests matching '" + query + "' found."), nil + } + return mcp.NewToolResultText(fmt.Sprintf("Found %d results:\n%s", len(lines), strings.Join(lines, "\n"))), nil + } +} + +func updateRequestStatusHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + userID := UserIDFromContext(ctx) + if userID == "" { + return mcp.NewToolResultError("Not authenticated"), nil + } + + requestID, err := req.RequireString("request_id") + if err != nil { + return mcp.NewToolResultError("request_id is required"), nil + } + status, err := req.RequireString("status") + if err != nil { + return mcp.NewToolResultError("status is required"), nil + } + + validStatuses := map[string]bool{"open": true, "in_process": true, "partial": true, "complete": true} + if !validStatuses[status] { + return mcp.NewToolResultError("Invalid status. Must be: open, in_process, partial, or complete"), nil + } + + entry, err := lib.EntryByID(db, cfg, requestID) + if err != nil { + return mcp.NewToolResultError("Request not found"), nil + } + if entry.Type != lib.TypeRequest { + return mcp.NewToolResultError("Entry is not a request"), nil + } + + // Parse current data, update status + var data lib.RequestData + if entry.DataText != "" { + _ = json.Unmarshal([]byte(entry.DataText), &data) + } + data.Status = status + + dataJSON, _ := json.Marshal(data) + entry.DataText = string(dataJSON) + + if err := lib.EntryWrite(db, cfg, userID, entry); err != nil { + return mcp.NewToolResultError("Failed to update: " + err.Error()), nil + } + + return mcp.NewToolResultText(fmt.Sprintf("Updated request %s status to '%s'", requestID, status)), nil + } +} + +func addCommentHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + userID := UserIDFromContext(ctx) + if userID == "" { + return mcp.NewToolResultError("Not authenticated"), nil + } + + requestID, err := req.RequireString("request_id") + if err != nil { + return mcp.NewToolResultError("request_id is required"), nil + } + comment, err := req.RequireString("comment") + if err != nil { + return mcp.NewToolResultError("comment is required"), nil + } + side, err := req.RequireString("side") + if err != nil { + return mcp.NewToolResultError("side is required"), nil + } + + if side != "buyer" && side != "seller" { + return mcp.NewToolResultError("side must be 'buyer' or 'seller'"), nil + } + + entry, err := lib.EntryByID(db, cfg, requestID) + if err != nil { + return mcp.NewToolResultError("Request not found"), nil + } + if entry.Type != lib.TypeRequest { + return mcp.NewToolResultError("Entry is not a request"), nil + } + + var data lib.RequestData + if entry.DataText != "" { + _ = json.Unmarshal([]byte(entry.DataText), &data) + } + + if side == "buyer" { + data.BuyerComment = comment + } else { + data.SellerComment = comment + } + + dataJSON, _ := json.Marshal(data) + entry.DataText = string(dataJSON) + + if err := lib.EntryWrite(db, cfg, userID, entry); err != nil { + return mcp.NewToolResultError("Failed to update: " + err.Error()), nil + } + + return mcp.NewToolResultText(fmt.Sprintf("Added %s comment to request %s", side, requestID)), nil + } +} + +func listAnswersHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + userID := UserIDFromContext(ctx) + if userID == "" { + return mcp.NewToolResultError("Not authenticated"), nil + } + + projectID, err := req.RequireString("project_id") + if err != nil { + return mcp.NewToolResultError("project_id is required"), nil + } + + if err := lib.CheckAccessRead(db, userID, projectID, ""); err != nil { + return mcp.NewToolResultError("Access denied"), nil + } + + filter := lib.EntryFilter{ProjectID: projectID, Type: lib.TypeAnswer} + entries, err := lib.EntryRead(db, cfg, userID, projectID, filter) + if err != nil { + return mcp.NewToolResultError("Failed to list answers: " + err.Error()), nil + } + + var lines []string + for _, e := range entries { + lines = append(lines, fmt.Sprintf("- %s (id: %s)", e.SummaryText, e.EntryID)) + } + + if len(lines) == 0 { + return mcp.NewToolResultText("No answer documents found."), nil + } + return mcp.NewToolResultText("Answers:\n" + strings.Join(lines, "\n")), nil + } +} + +func linkAnswerHandler(db *lib.DB) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + userID := UserIDFromContext(ctx) + if userID == "" { + return mcp.NewToolResultError("Not authenticated"), nil + } + + requestID, err := req.RequireString("request_id") + if err != nil { + return mcp.NewToolResultError("request_id is required"), nil + } + answerID, err := req.RequireString("answer_id") + if err != nil { + return mcp.NewToolResultError("answer_id is required"), nil + } + + if err := lib.AnswerLinkCreate(db, answerID, requestID, userID); err != nil { + return mcp.NewToolResultError("Failed to link answer: " + err.Error()), nil + } + + return mcp.NewToolResultText(fmt.Sprintf("Linked answer %s to request %s", answerID, requestID)), nil + } +} + +func listTasksHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc { + return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + userID := UserIDFromContext(ctx) + if userID == "" { + return mcp.NewToolResultError("Not authenticated"), nil + } + + tasks, err := lib.TasksByUser(db, cfg, userID) + if err != nil { + return mcp.NewToolResultError("Failed to list tasks: " + err.Error()), nil + } + + projectID := req.GetString("project_id", "") + + var lines []string + for _, t := range tasks { + if projectID != "" && t.ProjectID != projectID { + continue + } + lines = append(lines, fmt.Sprintf("- %s (id: %s, project: %s, stage: %s)", + t.SummaryText, t.EntryID, t.ProjectID, t.Stage)) + } + + if len(lines) == 0 { + return mcp.NewToolResultText("No tasks assigned to you."), nil + } + return mcp.NewToolResultText("Your tasks:\n" + strings.Join(lines, "\n")), nil + } +} diff --git a/api/middleware.go b/api/middleware.go index 5786db0..832a764 100644 --- a/api/middleware.go +++ b/api/middleware.go @@ -60,6 +60,31 @@ func AuthMiddleware(db *lib.DB, jwtSecret []byte) func(http.Handler) http.Handle } } +// OAuthBearerAuth validates OAuth 2.0 bearer tokens for MCP endpoints. +func OAuthBearerAuth(db *lib.DB) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if !strings.HasPrefix(auth, "Bearer ") { + w.Header().Set("WWW-Authenticate", "Bearer") + ErrorResponse(w, http.StatusUnauthorized, "missing_token", "Bearer token required") + return + } + tokenStr := strings.TrimPrefix(auth, "Bearer ") + + token, err := lib.OAuthTokenValidate(db, tokenStr) + if err != nil || token == nil { + w.Header().Set("WWW-Authenticate", "Bearer error=\"invalid_token\"") + ErrorResponse(w, http.StatusUnauthorized, "invalid_token", "Invalid or expired token") + return + } + + ctx := context.WithValue(r.Context(), ctxUserID, token.UserID) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + // LoggingMiddleware logs HTTP requests. func LoggingMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/api/oauth.go b/api/oauth.go new file mode 100644 index 0000000..8adfb9a --- /dev/null +++ b/api/oauth.go @@ -0,0 +1,417 @@ +package api + +import ( + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "encoding/json" + "html/template" + "net/http" + "net/url" + "os" + "path/filepath" + "strings" + "time" + + "github.com/mish/dealspace/lib" +) + +// OAuthHandlers holds dependencies for OAuth endpoints. +type OAuthHandlers struct { + DB *lib.DB + Cfg *lib.Config +} + +// NewOAuthHandlers creates OAuth handlers. +func NewOAuthHandlers(db *lib.DB, cfg *lib.Config) *OAuthHandlers { + return &OAuthHandlers{DB: db, Cfg: cfg} +} + +func baseURL(r *http.Request) string { + scheme := "https" + if r.TLS == nil && !strings.HasPrefix(r.Header.Get("X-Forwarded-Proto"), "https") { + scheme = "http" + } + return scheme + "://" + r.Host +} + +// Metadata serves GET /.well-known/oauth-authorization-server (RFC 8414). +func (o *OAuthHandlers) Metadata(w http.ResponseWriter, r *http.Request) { + base := baseURL(r) + JSONResponse(w, http.StatusOK, map[string]any{ + "issuer": base, + "authorization_endpoint": base + "/oauth/authorize", + "token_endpoint": base + "/oauth/token", + "revocation_endpoint": base + "/oauth/revoke", + "response_types_supported": []string{"code"}, + "grant_types_supported": []string{"authorization_code"}, + "code_challenge_methods_supported": []string{"S256"}, + "token_endpoint_auth_methods_supported": []string{"none"}, + }) +} + +// ResourceMetadata serves GET /.well-known/oauth-protected-resource (RFC 9728). +func (o *OAuthHandlers) ResourceMetadata(w http.ResponseWriter, r *http.Request) { + base := baseURL(r) + JSONResponse(w, http.StatusOK, map[string]any{ + "resource": base, + "authorization_servers": []string{base}, + }) +} + +// Authorize handles GET /oauth/authorize — shows consent page or redirects to login. +func (o *OAuthHandlers) Authorize(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + clientID := q.Get("client_id") + redirectURI := q.Get("redirect_uri") + responseType := q.Get("response_type") + codeChallenge := q.Get("code_challenge") + codeChallengeMethod := q.Get("code_challenge_method") + state := q.Get("state") + _ = q.Get("scope") // scope passed through via query string to POST handler + + // Validate required params + if responseType != "code" { + oauthError(w, redirectURI, state, "unsupported_response_type", "Only response_type=code is supported") + return + } + if codeChallenge == "" || codeChallengeMethod != "S256" { + oauthError(w, redirectURI, state, "invalid_request", "PKCE with S256 is required") + return + } + + // Validate client + client, err := lib.OAuthClientByID(o.DB, clientID) + if err != nil || client == nil { + oauthError(w, redirectURI, state, "invalid_client", "Unknown client_id") + return + } + + // Validate redirect_uri + if !validRedirectURI(client.RedirectURIs, redirectURI) { + ErrorResponse(w, http.StatusBadRequest, "invalid_request", "Invalid redirect_uri") + return + } + + // Check if user is authenticated via JWT in cookie or Authorization header + userID := o.extractUserID(r) + if userID == "" { + // Redirect to login with return URL + loginURL := "/app/login?next=" + url.QueryEscape(r.URL.RequestURI()) + http.Redirect(w, r, loginURL, http.StatusFound) + return + } + + // Show consent page + o.serveConsentPage(w, client.ClientName, r.URL.RequestURI()) +} + +// AuthorizeApprove handles POST /oauth/authorize — processes consent approval. +func (o *OAuthHandlers) AuthorizeApprove(w http.ResponseWriter, r *http.Request) { + // Parse the original authorize params from the form + if err := r.ParseForm(); err != nil { + ErrorResponse(w, http.StatusBadRequest, "invalid_request", "Invalid form data") + return + } + + // The form includes the original query string in a hidden field + originalQuery := r.FormValue("original_query") + parsedQuery, err := url.ParseQuery(originalQuery) + if err != nil { + ErrorResponse(w, http.StatusBadRequest, "invalid_request", "Invalid query parameters") + return + } + + clientID := parsedQuery.Get("client_id") + redirectURI := parsedQuery.Get("redirect_uri") + codeChallenge := parsedQuery.Get("code_challenge") + state := parsedQuery.Get("state") + scope := parsedQuery.Get("scope") + + // Check denial + if r.FormValue("action") == "deny" { + oauthRedirect(w, r, redirectURI, state, "", "access_denied") + return + } + + // Validate client again + client, err := lib.OAuthClientByID(o.DB, clientID) + if err != nil || client == nil { + ErrorResponse(w, http.StatusBadRequest, "invalid_client", "Unknown client_id") + return + } + if !validRedirectURI(client.RedirectURIs, redirectURI) { + ErrorResponse(w, http.StatusBadRequest, "invalid_request", "Invalid redirect_uri") + return + } + + // User must be authenticated + userID := o.extractUserID(r) + if userID == "" { + ErrorResponse(w, http.StatusUnauthorized, "login_required", "Authentication required") + return + } + + // Generate authorization code + codeBytes := make([]byte, 32) + if _, err := rand.Read(codeBytes); err != nil { + ErrorResponse(w, http.StatusInternalServerError, "internal", "Failed to generate code") + return + } + codeStr := hex.EncodeToString(codeBytes) + + oauthCode := &lib.OAuthCode{ + Code: codeStr, + ClientID: clientID, + UserID: userID, + RedirectURI: redirectURI, + CodeChallenge: codeChallenge, + Scope: scope, + ExpiresAt: time.Now().Add(10 * time.Minute).UnixMilli(), + Used: false, + } + if err := lib.OAuthCodeCreate(o.DB, oauthCode); err != nil { + ErrorResponse(w, http.StatusInternalServerError, "internal", "Failed to store authorization code") + return + } + + oauthRedirect(w, r, redirectURI, state, codeStr, "") +} + +// Token handles POST /oauth/token — exchanges authorization code for access token. +func (o *OAuthHandlers) Token(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + tokenError(w, "invalid_request", "Invalid form data") + return + } + + grantType := r.FormValue("grant_type") + if grantType != "authorization_code" { + tokenError(w, "unsupported_grant_type", "Only authorization_code is supported") + return + } + + codeStr := r.FormValue("code") + redirectURI := r.FormValue("redirect_uri") + clientID := r.FormValue("client_id") + codeVerifier := r.FormValue("code_verifier") + + if codeStr == "" || clientID == "" || codeVerifier == "" { + tokenError(w, "invalid_request", "Missing required parameters") + return + } + + // Consume the code (marks used, checks expiry) + code, err := lib.OAuthCodeConsume(o.DB, codeStr) + if err != nil { + tokenError(w, "invalid_grant", "Invalid or expired authorization code") + return + } + + // Verify client_id matches + if code.ClientID != clientID { + tokenError(w, "invalid_grant", "Client mismatch") + return + } + + // Verify redirect_uri matches + if code.RedirectURI != redirectURI { + tokenError(w, "invalid_grant", "Redirect URI mismatch") + return + } + + // Verify PKCE: SHA256(code_verifier) must match code_challenge + verifierHash := sha256.Sum256([]byte(codeVerifier)) + computedChallenge := base64.RawURLEncoding.EncodeToString(verifierHash[:]) + if computedChallenge != code.CodeChallenge { + tokenError(w, "invalid_grant", "PKCE verification failed") + return + } + + // Generate access token + tokenBytes := make([]byte, 64) + if _, err := rand.Read(tokenBytes); err != nil { + tokenError(w, "server_error", "Failed to generate token") + return + } + tokenStr := hex.EncodeToString(tokenBytes) + now := time.Now() + expiresIn := int64(24 * 60 * 60) // 24 hours in seconds + + oauthToken := &lib.OAuthToken{ + Token: tokenStr, + ClientID: clientID, + UserID: code.UserID, + Scope: code.Scope, + ExpiresAt: now.Add(24 * time.Hour).UnixMilli(), + Revoked: false, + CreatedAt: now.UnixMilli(), + } + if err := lib.OAuthTokenCreate(o.DB, oauthToken); err != nil { + tokenError(w, "server_error", "Failed to store token") + return + } + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Cache-Control", "no-store") + json.NewEncoder(w).Encode(map[string]any{ + "access_token": tokenStr, + "token_type": "Bearer", + "expires_in": expiresIn, + }) +} + +// Revoke handles POST /oauth/revoke — revokes an access token (RFC 7009). +func (o *OAuthHandlers) Revoke(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + w.WriteHeader(http.StatusOK) + return + } + tokenStr := r.FormValue("token") + if tokenStr != "" { + _ = lib.OAuthTokenRevoke(o.DB, tokenStr) + } + // Always return 200 per RFC 7009 + w.WriteHeader(http.StatusOK) +} + +// extractUserID checks JWT from Authorization header or ds_token cookie. +func (o *OAuthHandlers) extractUserID(r *http.Request) string { + // Try Authorization header first + auth := r.Header.Get("Authorization") + if strings.HasPrefix(auth, "Bearer ") { + token := strings.TrimPrefix(auth, "Bearer ") + claims, err := validateJWT(token, o.Cfg.JWTSecret) + if err == nil { + session, err := lib.SessionByID(o.DB, claims.SessionID) + if err == nil && session != nil && !session.Revoked && session.ExpiresAt >= time.Now().UnixMilli() { + return claims.UserID + } + } + } + + // Try cookie + cookie, err := r.Cookie("ds_token") + if err == nil && cookie.Value != "" { + claims, err := validateJWT(cookie.Value, o.Cfg.JWTSecret) + if err == nil { + session, err := lib.SessionByID(o.DB, claims.SessionID) + if err == nil && session != nil && !session.Revoked && session.ExpiresAt >= time.Now().UnixMilli() { + return claims.UserID + } + } + } + + return "" +} + +func (o *OAuthHandlers) serveConsentPage(w http.ResponseWriter, appName, authorizeURI string) { + // Extract query string from the authorize URI + parsed, _ := url.Parse(authorizeURI) + queryString := parsed.RawQuery + + candidates := []string{ + "portal/templates/auth/consent.html", + filepath.Join("/opt/dealspace/portal/templates/auth/consent.html"), + } + + var tmpl *template.Template + var err error + for _, p := range candidates { + if _, statErr := os.Stat(p); statErr == nil { + tmpl, err = template.ParseFiles(p) + if err == nil { + break + } + } + } + + if tmpl == nil { + http.Error(w, "Consent template not found", http.StatusInternalServerError) + return + } + + data := map[string]string{ + "AppName": appName, + "OriginalQuery": queryString, + } + w.Header().Set("Content-Type", "text/html; charset=utf-8") + tmpl.Execute(w, data) +} + +func validRedirectURI(registered []string, uri string) bool { + parsed, err := url.Parse(uri) + if err != nil { + return false + } + for _, r := range registered { + rParsed, err := url.Parse(r) + if err != nil { + continue + } + // Match scheme + host (port-agnostic for localhost) + if parsed.Scheme == rParsed.Scheme && parsed.Hostname() == rParsed.Hostname() { + return true + } + } + return false +} + +func oauthError(w http.ResponseWriter, redirectURI, state, errCode, errDesc string) { + if redirectURI == "" { + ErrorResponse(w, http.StatusBadRequest, errCode, errDesc) + return + } + oauthRedirect(w, nil, redirectURI, state, "", errCode) +} + +func oauthRedirect(w http.ResponseWriter, r *http.Request, redirectURI, state, code, errCode string) { + u, err := url.Parse(redirectURI) + if err != nil { + ErrorResponse(w, http.StatusBadRequest, "invalid_request", "Invalid redirect_uri") + return + } + q := u.Query() + if code != "" { + q.Set("code", code) + } + if errCode != "" { + q.Set("error", errCode) + } + if state != "" { + q.Set("state", state) + } + u.RawQuery = q.Encode() + if r != nil { + http.Redirect(w, r, u.String(), http.StatusFound) + } else { + w.Header().Set("Location", u.String()) + w.WriteHeader(http.StatusFound) + } +} + +func tokenError(w http.ResponseWriter, errCode, errDesc string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(map[string]string{ + "error": errCode, + "error_description": errDesc, + }) +} + +// SeedOAuthClient creates the default Claude OAuth client if it doesn't exist. +func SeedOAuthClient(db *lib.DB) { + client := &lib.OAuthClient{ + ClientID: "claude", + ClientName: "Claude", + RedirectURIs: []string{ + "http://localhost", + "http://127.0.0.1", + "http://localhost:0", + "http://127.0.0.1:0", + }, + CreatedAt: time.Now().UnixMilli(), + } + _ = lib.OAuthClientCreate(db, client) +} diff --git a/api/routes.go b/api/routes.go index 0ff67ef..9214adb 100644 --- a/api/routes.go +++ b/api/routes.go @@ -1,17 +1,20 @@ package api import ( + "context" "io/fs" "net/http" "github.com/go-chi/chi/v5" + "github.com/mark3labs/mcp-go/server" "github.com/mish/dealspace/lib" ) // NewRouter creates the main router with all routes registered. -func NewRouter(db *lib.DB, cfg *lib.Config, store lib.ObjectStore, websiteFS fs.FS, portalFS fs.FS) *chi.Mux { +func NewRouter(db *lib.DB, cfg *lib.Config, store lib.ObjectStore, websiteFS fs.FS, portalFS fs.FS, mcpServer *server.MCPServer) *chi.Mux { r := chi.NewRouter() h := NewHandlers(db, cfg, store) + oauth := NewOAuthHandlers(db, cfg) // Global middleware r.Use(LoggingMiddleware) @@ -103,6 +106,33 @@ func NewRouter(db *lib.DB, cfg *lib.Config, store lib.ObjectStore, websiteFS fs. r.Post("/admin/impersonate", h.AdminImpersonate) }) + // OAuth metadata (unauthenticated, auto-discovery) + r.Get("/.well-known/oauth-authorization-server", oauth.Metadata) + r.Get("/.well-known/oauth-protected-resource", oauth.ResourceMetadata) + + // OAuth endpoints + r.Get("/oauth/authorize", oauth.Authorize) + r.Post("/oauth/authorize", oauth.AuthorizeApprove) + r.Post("/oauth/token", oauth.Token) + r.Post("/oauth/revoke", oauth.Revoke) + + // MCP endpoint (OAuth bearer auth) + if mcpServer != nil { + mcpHTTP := server.NewStreamableHTTPServer(mcpServer, + server.WithEndpointPath("/mcp"), + server.WithHTTPContextFunc(func(ctx context.Context, r *http.Request) context.Context { + if uid := UserIDFromContext(r.Context()); uid != "" { + return context.WithValue(ctx, ctxUserID, uid) + } + return ctx + }), + ) + r.Group(func(r chi.Router) { + r.Use(OAuthBearerAuth(db)) + r.Handle("/mcp", mcpHTTP) + }) + } + // Portal app routes (serve templates, auth checked client-side via JS) r.Get("/app", func(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, "/app/projects", http.StatusFound) diff --git a/cmd/server/main.go b/cmd/server/main.go index 22b1865..6da88e4 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -40,6 +40,9 @@ func main() { // Always seed super admin accounts on startup seedSuperAdmins(db) + // Seed OAuth client for Claude + api.SeedOAuthClient(db) + // Seed demo data if SEED_DEMO=true if os.Getenv("SEED_DEMO") == "true" { seedDemoData(db, cfg) @@ -51,7 +54,10 @@ func main() { log.Fatalf("embed website: %v", err) } - router := api.NewRouter(db, cfg, store, websiteFS, nil) + // MCP server for Claude integration + mcpServer := api.NewMCPServer(db, cfg) + + router := api.NewRouter(db, cfg, store, websiteFS, nil, mcpServer) addr := ":" + cfg.Port log.Printf("dealspace starting on %s (env=%s)", addr, cfg.Env) diff --git a/go.mod b/go.mod index 316a206..d306068 100644 --- a/go.mod +++ b/go.mod @@ -6,26 +6,35 @@ require ( github.com/go-chi/chi/v5 v5.2.1 github.com/google/uuid v1.6.0 github.com/klauspost/compress v1.18.0 + github.com/mark3labs/mcp-go v0.44.1 github.com/mattn/go-sqlite3 v1.14.24 github.com/pdfcpu/pdfcpu v0.11.1 + github.com/xuri/excelize/v2 v2.10.1 golang.org/x/crypto v0.48.0 ) require ( + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect github.com/clipperhouse/uax29/v2 v2.2.0 // indirect github.com/hhrutter/lzw v1.0.0 // indirect github.com/hhrutter/pkcs7 v0.2.0 // indirect github.com/hhrutter/tiff v1.0.2 // indirect + github.com/invopop/jsonschema v0.13.0 // indirect + github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-runewidth v0.0.19 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/richardlehane/mscfb v1.0.6 // indirect github.com/richardlehane/msoleps v1.0.6 // indirect + github.com/spf13/cast v1.7.1 // indirect github.com/tiendc/go-deepcopy v1.7.2 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/xuri/efp v0.0.1 // indirect - github.com/xuri/excelize/v2 v2.10.1 // indirect github.com/xuri/nfp v0.0.2-0.20250530014748-2ddeb826f9a9 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect golang.org/x/image v0.32.0 // indirect golang.org/x/net v0.50.0 // indirect golang.org/x/text v0.34.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 0890211..73ecf47 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,17 @@ +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/clipperhouse/uax29/v2 v2.2.0 h1:ChwIKnQN3kcZteTXMgb1wztSgaU+ZemkgWdohwgs8tY= github.com/clipperhouse/uax29/v2 v2.2.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/go-chi/chi/v5 v5.2.1 h1:KOIHODQj58PmL80G2Eak4WdvUzjSJSm0vG72crDCqb8= github.com/go-chi/chi/v5 v5.2.1/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hhrutter/lzw v1.0.0 h1:laL89Llp86W3rRs83LvKbwYRx6INE8gDn0XNb1oXtm0= @@ -10,8 +20,19 @@ github.com/hhrutter/pkcs7 v0.2.0 h1:i4HN2XMbGQpZRnKBLsUwO3dSckzgX142TNqY/KfXg+I= github.com/hhrutter/pkcs7 v0.2.0/go.mod h1:aEzKz0+ZAlz7YaEMY47jDHL14hVWD6iXt0AgqgAvWgE= github.com/hhrutter/tiff v1.0.2 h1:7H3FQQpKu/i5WaSChoD1nnJbGx4MxU5TlNqqpxw55z8= github.com/hhrutter/tiff v1.0.2/go.mod h1:pcOeuK5loFUE7Y/WnzGw20YxUdnqjY1P0Jlcieb/cCw= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mark3labs/mcp-go v0.44.1 h1:2PKppYlT9X2fXnE8SNYQLAX4hNjfPB0oNLqQVcN6mE8= +github.com/mark3labs/mcp-go v0.44.1/go.mod h1:YnJfOL382MIWDx1kMY+2zsRHU/q78dBg9aFb8W6Thdw= github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw= github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= @@ -20,31 +41,41 @@ github.com/pdfcpu/pdfcpu v0.11.1 h1:htHBSkGH5jMKWC6e0sihBFbcKZ8vG1M67c8/dJxhjas= github.com/pdfcpu/pdfcpu v0.11.1/go.mod h1:pP3aGga7pRvwFWAm9WwFvo+V68DfANi9kxSQYioNYcw= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/richardlehane/mscfb v1.0.6 h1:eN3bvvZCp00bs7Zf52bxNwAx5lJDBK1tCuH19qq5aC8= github.com/richardlehane/mscfb v1.0.6/go.mod h1:pe0+IUIc0AHh0+teNzBlJCtSyZdFOGgV4ZK9bsoV+Jo= github.com/richardlehane/msoleps v1.0.6 h1:9BvkpjvD+iUBalUY4esMwv6uBkfOip/Lzvd93jvR9gg= github.com/richardlehane/msoleps v1.0.6/go.mod h1:BWev5JBpU9Ko2WAgmZEuiz4/u3ZYTKbjLycmwiWUfWg= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= +github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/tiendc/go-deepcopy v1.7.2 h1:Ut2yYR7W9tWjTQitganoIue4UGxZwCcJy3orjrrIj44= github.com/tiendc/go-deepcopy v1.7.2/go.mod h1:4bKjNC2r7boYOkD2IOuZpYjmlDdzjbpTRyCx+goBCJQ= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/xuri/efp v0.0.1 h1:fws5Rv3myXyYni8uwj2qKjVaRP30PdjeYe2Y6FDsCL8= github.com/xuri/efp v0.0.1/go.mod h1:ybY/Jr0T0GTCnYjKqmdwxyxn2BQf2RcQIIvex5QldPI= github.com/xuri/excelize/v2 v2.10.1 h1:V62UlqopMqha3kOpnlHy2CcRVw1V8E63jFoWUmMzxN0= github.com/xuri/excelize/v2 v2.10.1/go.mod h1:iG5tARpgaEeIhTqt3/fgXCGoBRt4hNXgCp3tfXKoOIc= github.com/xuri/nfp v0.0.2-0.20250530014748-2ddeb826f9a9 h1:+C0TIdyyYmzadGaL/HBLbf3WdLgC29pgyhTjAT/0nuE= github.com/xuri/nfp v0.0.2-0.20250530014748-2ddeb826f9a9/go.mod h1:WwHg+CVyzlv/TX9xqBFXEZAuxOPxn2k1GNHwG41IIUQ= -golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04= -golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= golang.org/x/image v0.32.0 h1:6lZQWq75h7L5IWNk0r+SCpUJ6tUVd3v4ZHnbRKLkUDQ= golang.org/x/image v0.32.0/go.mod h1:/R37rrQmKXtO6tYXAjtDLwQgFLHmhW+V6ayXlxzP2Pc= golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60= golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM= -golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k= -golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM= golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/lib/dbcore.go b/lib/dbcore.go index 27b164d..b1358ba 100644 --- a/lib/dbcore.go +++ b/lib/dbcore.go @@ -3,6 +3,7 @@ package lib import ( "crypto/subtle" "database/sql" + "encoding/json" "errors" "fmt" "os" @@ -1003,3 +1004,99 @@ func SoftDeleteTree(db *DB, entryID, actorID string) error { } return nil } + +// --- OAuth --- + +// OAuthClientByID looks up an OAuth client by its client_id. +func OAuthClientByID(db *DB, clientID string) (*OAuthClient, error) { + row := db.Conn.QueryRow(`SELECT client_id, client_name, redirect_uris, created_at FROM oauth_clients WHERE client_id = ?`, clientID) + var c OAuthClient + var uris string + if err := row.Scan(&c.ClientID, &c.ClientName, &uris, &c.CreatedAt); err != nil { + return nil, err + } + _ = json.Unmarshal([]byte(uris), &c.RedirectURIs) + return &c, nil +} + +// OAuthClientCreate inserts a new OAuth client (skips if already exists). +func OAuthClientCreate(db *DB, c *OAuthClient) error { + uris, _ := json.Marshal(c.RedirectURIs) + _, err := db.Conn.Exec( + `INSERT OR IGNORE INTO oauth_clients (client_id, client_name, redirect_uris, created_at) VALUES (?, ?, ?, ?)`, + c.ClientID, c.ClientName, string(uris), c.CreatedAt, + ) + return err +} + +// OAuthCodeCreate stores an authorization code. +func OAuthCodeCreate(db *DB, code *OAuthCode) error { + used := 0 + if code.Used { + used = 1 + } + _, err := db.Conn.Exec( + `INSERT INTO oauth_codes (code, client_id, user_id, redirect_uri, code_challenge, scope, expires_at, used) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, + code.Code, code.ClientID, code.UserID, code.RedirectURI, code.CodeChallenge, code.Scope, code.ExpiresAt, used, + ) + return err +} + +// OAuthCodeConsume looks up and marks an authorization code as used. Returns nil if invalid/expired/used. +func OAuthCodeConsume(db *DB, codeStr string) (*OAuthCode, error) { + row := db.Conn.QueryRow( + `SELECT code, client_id, user_id, redirect_uri, code_challenge, scope, expires_at, used FROM oauth_codes WHERE code = ?`, + codeStr, + ) + var c OAuthCode + var used int + if err := row.Scan(&c.Code, &c.ClientID, &c.UserID, &c.RedirectURI, &c.CodeChallenge, &c.Scope, &c.ExpiresAt, &used); err != nil { + return nil, err + } + if used == 1 || c.ExpiresAt < time.Now().UnixMilli() { + return nil, ErrAccessDenied + } + // Mark as used + _, err := db.Conn.Exec(`UPDATE oauth_codes SET used = 1 WHERE code = ?`, codeStr) + if err != nil { + return nil, err + } + return &c, nil +} + +// OAuthTokenCreate stores an access token. +func OAuthTokenCreate(db *DB, t *OAuthToken) error { + revoked := 0 + if t.Revoked { + revoked = 1 + } + _, err := db.Conn.Exec( + `INSERT INTO oauth_tokens (token, client_id, user_id, scope, expires_at, revoked, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)`, + t.Token, t.ClientID, t.UserID, t.Scope, t.ExpiresAt, revoked, t.CreatedAt, + ) + return err +} + +// OAuthTokenValidate returns a token if it's valid (not expired, not revoked). +func OAuthTokenValidate(db *DB, tokenStr string) (*OAuthToken, error) { + row := db.Conn.QueryRow( + `SELECT token, client_id, user_id, scope, expires_at, revoked, created_at FROM oauth_tokens WHERE token = ?`, + tokenStr, + ) + var t OAuthToken + var revoked int + if err := row.Scan(&t.Token, &t.ClientID, &t.UserID, &t.Scope, &t.ExpiresAt, &revoked, &t.CreatedAt); err != nil { + return nil, err + } + t.Revoked = revoked == 1 + if t.Revoked || t.ExpiresAt < time.Now().UnixMilli() { + return nil, ErrAccessDenied + } + return &t, nil +} + +// OAuthTokenRevoke marks a token as revoked. +func OAuthTokenRevoke(db *DB, tokenStr string) error { + _, err := db.Conn.Exec(`UPDATE oauth_tokens SET revoked = 1 WHERE token = ?`, tokenStr) + return err +} diff --git a/lib/types.go b/lib/types.go index 90cc3fa..f21cf4f 100644 --- a/lib/types.go +++ b/lib/types.go @@ -270,3 +270,34 @@ type WorkstreamData struct { Name string `json:"name"` Description string `json:"description,omitempty"` } + +// OAuthClient represents a registered OAuth 2.0 client application. +type OAuthClient struct { + ClientID string `json:"client_id"` + ClientName string `json:"client_name"` + RedirectURIs []string `json:"redirect_uris"` + CreatedAt int64 `json:"created_at"` +} + +// OAuthCode represents an OAuth 2.0 authorization code. +type OAuthCode struct { + Code string `json:"code"` + ClientID string `json:"client_id"` + UserID string `json:"user_id"` + RedirectURI string `json:"redirect_uri"` + CodeChallenge string `json:"code_challenge"` + Scope string `json:"scope"` + ExpiresAt int64 `json:"expires_at"` + Used bool `json:"used"` +} + +// OAuthToken represents an OAuth 2.0 access token. +type OAuthToken struct { + Token string `json:"token"` + ClientID string `json:"client_id"` + UserID string `json:"user_id"` + Scope string `json:"scope"` + ExpiresAt int64 `json:"expires_at"` + Revoked bool `json:"revoked"` + CreatedAt int64 `json:"created_at"` +} diff --git a/migrations/003_oauth.sql b/migrations/003_oauth.sql new file mode 100644 index 0000000..e77a3af --- /dev/null +++ b/migrations/003_oauth.sql @@ -0,0 +1,27 @@ +CREATE TABLE IF NOT EXISTS oauth_clients ( + client_id TEXT PRIMARY KEY, + client_name TEXT NOT NULL, + redirect_uris TEXT NOT NULL, + created_at INTEGER NOT NULL +); + +CREATE TABLE IF NOT EXISTS oauth_codes ( + code TEXT PRIMARY KEY, + client_id TEXT NOT NULL, + user_id TEXT NOT NULL, + redirect_uri TEXT NOT NULL, + code_challenge TEXT NOT NULL, + scope TEXT NOT NULL DEFAULT '', + expires_at INTEGER NOT NULL, + used INTEGER NOT NULL DEFAULT 0 +); + +CREATE TABLE IF NOT EXISTS oauth_tokens ( + token TEXT PRIMARY KEY, + client_id TEXT NOT NULL, + user_id TEXT NOT NULL, + scope TEXT NOT NULL DEFAULT '', + expires_at INTEGER NOT NULL, + revoked INTEGER NOT NULL DEFAULT 0, + created_at INTEGER NOT NULL +); diff --git a/portal/templates/auth/consent.html b/portal/templates/auth/consent.html new file mode 100644 index 0000000..f0a9673 --- /dev/null +++ b/portal/templates/auth/consent.html @@ -0,0 +1,58 @@ + + +
+ + +This application is requesting access to your Dealspace account.
+You can revoke access at any time from your account settings.
+