chore: auto-commit uncommitted changes

This commit is contained in:
James 2026-03-04 00:01:22 -05:00
parent d52921e1f3
commit 27c715f963
12 changed files with 1258 additions and 10 deletions

508
api/mcp.go Normal file
View File

@ -0,0 +1,508 @@
package api
import (
"context"
"encoding/json"
"fmt"
"strings"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
"github.com/mish/dealspace/lib"
)
// NewMCPServer creates an MCP server with all Dealspace tools registered.
func NewMCPServer(db *lib.DB, cfg *lib.Config) *server.MCPServer {
s := server.NewMCPServer(
"dealspace",
"1.0.0",
server.WithToolCapabilities(false),
server.WithRecovery(),
)
s.AddTool(listProjectsTool(), listProjectsHandler(db, cfg))
s.AddTool(getProjectTool(), getProjectHandler(db, cfg))
s.AddTool(listRequestsTool(), listRequestsHandler(db, cfg))
s.AddTool(getRequestTool(), getRequestHandler(db, cfg))
s.AddTool(searchRequestsTool(), searchRequestsHandler(db, cfg))
s.AddTool(updateRequestStatusTool(), updateRequestStatusHandler(db, cfg))
s.AddTool(addCommentTool(), addCommentHandler(db, cfg))
s.AddTool(listAnswersTool(), listAnswersHandler(db, cfg))
s.AddTool(linkAnswerTool(), linkAnswerHandler(db))
s.AddTool(listTasksTool(), listTasksHandler(db, cfg))
return s
}
// MCPContextFunc propagates user_id from OAuth middleware into MCP tool context.
func MCPContextFunc(ctx context.Context, r interface{ Context() context.Context }) context.Context {
if uid := UserIDFromContext(r.Context()); uid != "" {
return context.WithValue(ctx, ctxUserID, uid)
}
return ctx
}
// --- Tool definitions ---
func listProjectsTool() mcp.Tool {
return mcp.NewTool("list_projects",
mcp.WithDescription("List all projects the user has access to"),
)
}
func getProjectTool() mcp.Tool {
return mcp.NewTool("get_project",
mcp.WithDescription("Get project details including workstreams and summary"),
mcp.WithString("project_id", mcp.Description("The project ID"), mcp.Required()),
)
}
func listRequestsTool() mcp.Tool {
return mcp.NewTool("list_requests",
mcp.WithDescription("List all requests in a project as a tree (request lists, sections, requests)"),
mcp.WithString("project_id", mcp.Description("The project ID"), mcp.Required()),
)
}
func getRequestTool() mcp.Tool {
return mcp.NewTool("get_request",
mcp.WithDescription("Get request details including linked answers"),
mcp.WithString("request_id", mcp.Description("The request ID"), mcp.Required()),
)
}
func searchRequestsTool() mcp.Tool {
return mcp.NewTool("search_requests",
mcp.WithDescription("Search requests by keyword in title, description, or comments"),
mcp.WithString("project_id", mcp.Description("The project ID"), mcp.Required()),
mcp.WithString("query", mcp.Description("Search keyword"), mcp.Required()),
)
}
func updateRequestStatusTool() mcp.Tool {
return mcp.NewTool("update_request_status",
mcp.WithDescription("Update the status of a request"),
mcp.WithString("request_id", mcp.Description("The request ID"), mcp.Required()),
mcp.WithString("status", mcp.Description("New status: open, in_process, partial, or complete"), mcp.Required()),
)
}
func addCommentTool() mcp.Tool {
return mcp.NewTool("add_comment",
mcp.WithDescription("Add a buyer or seller comment to a request"),
mcp.WithString("request_id", mcp.Description("The request ID"), mcp.Required()),
mcp.WithString("comment", mcp.Description("The comment text"), mcp.Required()),
mcp.WithString("side", mcp.Description("Which side: buyer or seller"), mcp.Required()),
)
}
func listAnswersTool() mcp.Tool {
return mcp.NewTool("list_answers",
mcp.WithDescription("List answer documents in a project"),
mcp.WithString("project_id", mcp.Description("The project ID"), mcp.Required()),
)
}
func linkAnswerTool() mcp.Tool {
return mcp.NewTool("link_answer",
mcp.WithDescription("Link an existing answer document to a request"),
mcp.WithString("request_id", mcp.Description("The request ID"), mcp.Required()),
mcp.WithString("answer_id", mcp.Description("The answer entry ID"), mcp.Required()),
)
}
func listTasksTool() mcp.Tool {
return mcp.NewTool("list_tasks",
mcp.WithDescription("List tasks assigned to the current user"),
mcp.WithString("project_id", mcp.Description("Optional project ID to filter tasks")),
)
}
// --- Tool handlers ---
func listProjectsHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc {
return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID := UserIDFromContext(ctx)
if userID == "" {
return mcp.NewToolResultError("Not authenticated"), nil
}
projects, err := lib.ProjectsByUser(db, cfg, userID)
if err != nil {
return mcp.NewToolResultError("Failed to list projects: " + err.Error()), nil
}
var lines []string
for _, p := range projects {
var data map[string]any
if p.DataText != "" {
_ = json.Unmarshal([]byte(p.DataText), &data)
}
name, _ := data["name"].(string)
if name == "" {
name = p.EntryID
}
lines = append(lines, fmt.Sprintf("- %s (id: %s, stage: %s)", name, p.EntryID, p.Stage))
}
if len(lines) == 0 {
return mcp.NewToolResultText("No projects found."), nil
}
return mcp.NewToolResultText("Projects:\n" + strings.Join(lines, "\n")), nil
}
}
func getProjectHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc {
return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID := UserIDFromContext(ctx)
if userID == "" {
return mcp.NewToolResultError("Not authenticated"), nil
}
projectID, err := req.RequireString("project_id")
if err != nil {
return mcp.NewToolResultError("project_id is required"), nil
}
entry, err := lib.EntryByID(db, cfg, projectID)
if err != nil {
return mcp.NewToolResultError("Project not found"), nil
}
if entry.Type != lib.TypeProject {
return mcp.NewToolResultError("Not a project"), nil
}
return mcp.NewToolResultText(fmt.Sprintf("Project: %s\nID: %s\nStage: %s\nCreated: %d\nData: %s",
entry.SummaryText, entry.EntryID, entry.Stage, entry.CreatedAt, entry.DataText)), nil
}
}
func listRequestsHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc {
return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID := UserIDFromContext(ctx)
if userID == "" {
return mcp.NewToolResultError("Not authenticated"), nil
}
projectID, err := req.RequireString("project_id")
if err != nil {
return mcp.NewToolResultError("project_id is required"), nil
}
// Verify access
if err := lib.CheckAccessRead(db, userID, projectID, ""); err != nil {
return mcp.NewToolResultError("Access denied"), nil
}
filter := lib.EntryFilter{ProjectID: projectID}
entries, err := lib.EntryRead(db, cfg, userID, projectID, filter)
if err != nil {
return mcp.NewToolResultError("Failed to list entries: " + err.Error()), nil
}
// Filter to request types and build tree text
var lines []string
for _, e := range entries {
if e.Type != lib.TypeRequestList && e.Type != lib.TypeSection && e.Type != lib.TypeRequest {
continue
}
indent := strings.Repeat(" ", e.Depth)
var data map[string]any
if e.DataText != "" {
_ = json.Unmarshal([]byte(e.DataText), &data)
}
switch e.Type {
case lib.TypeRequestList:
name, _ := data["name"].(string)
lines = append(lines, fmt.Sprintf("%s[Request List] %s (id: %s)", indent, name, e.EntryID))
case lib.TypeSection:
name, _ := data["name"].(string)
lines = append(lines, fmt.Sprintf("%s[Section] %s (id: %s)", indent, name, e.EntryID))
case lib.TypeRequest:
title, _ := data["title"].(string)
status, _ := data["status"].(string)
priority, _ := data["priority"].(string)
lines = append(lines, fmt.Sprintf("%s- %s [%s/%s] (id: %s)", indent, title, status, priority, e.EntryID))
}
}
if len(lines) == 0 {
return mcp.NewToolResultText("No requests found in this project."), nil
}
return mcp.NewToolResultText("Request tree:\n" + strings.Join(lines, "\n")), nil
}
}
func getRequestHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc {
return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID := UserIDFromContext(ctx)
if userID == "" {
return mcp.NewToolResultError("Not authenticated"), nil
}
requestID, err := req.RequireString("request_id")
if err != nil {
return mcp.NewToolResultError("request_id is required"), nil
}
entry, err := lib.EntryByID(db, cfg, requestID)
if err != nil {
return mcp.NewToolResultError("Request not found"), nil
}
result := fmt.Sprintf("Request ID: %s\nProject: %s\nType: %s\nData: %s",
entry.EntryID, entry.ProjectID, entry.Type, entry.DataText)
// Get linked answers
links, err := lib.AnswerLinksByRequest(db, requestID)
if err == nil && len(links) > 0 {
result += "\n\nLinked Answers:"
for _, link := range links {
answerEntry, err := lib.EntryByID(db, cfg, link.AnswerID)
if err == nil {
result += fmt.Sprintf("\n - %s (id: %s)", answerEntry.SummaryText, link.AnswerID)
}
}
}
return mcp.NewToolResultText(result), nil
}
}
func searchRequestsHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc {
return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID := UserIDFromContext(ctx)
if userID == "" {
return mcp.NewToolResultError("Not authenticated"), nil
}
projectID, err := req.RequireString("project_id")
if err != nil {
return mcp.NewToolResultError("project_id is required"), nil
}
query, err := req.RequireString("query")
if err != nil {
return mcp.NewToolResultError("query is required"), nil
}
if err := lib.CheckAccessRead(db, userID, projectID, ""); err != nil {
return mcp.NewToolResultError("Access denied"), nil
}
filter := lib.EntryFilter{ProjectID: projectID, Type: lib.TypeRequest}
entries, err := lib.EntryRead(db, cfg, userID, projectID, filter)
if err != nil {
return mcp.NewToolResultError("Failed to search: " + err.Error()), nil
}
queryLower := strings.ToLower(query)
var lines []string
for _, e := range entries {
if strings.Contains(strings.ToLower(e.DataText), queryLower) ||
strings.Contains(strings.ToLower(e.SummaryText), queryLower) {
var data map[string]any
if e.DataText != "" {
_ = json.Unmarshal([]byte(e.DataText), &data)
}
title, _ := data["title"].(string)
status, _ := data["status"].(string)
lines = append(lines, fmt.Sprintf("- %s [%s] (id: %s)", title, status, e.EntryID))
}
}
if len(lines) == 0 {
return mcp.NewToolResultText("No requests matching '" + query + "' found."), nil
}
return mcp.NewToolResultText(fmt.Sprintf("Found %d results:\n%s", len(lines), strings.Join(lines, "\n"))), nil
}
}
func updateRequestStatusHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc {
return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID := UserIDFromContext(ctx)
if userID == "" {
return mcp.NewToolResultError("Not authenticated"), nil
}
requestID, err := req.RequireString("request_id")
if err != nil {
return mcp.NewToolResultError("request_id is required"), nil
}
status, err := req.RequireString("status")
if err != nil {
return mcp.NewToolResultError("status is required"), nil
}
validStatuses := map[string]bool{"open": true, "in_process": true, "partial": true, "complete": true}
if !validStatuses[status] {
return mcp.NewToolResultError("Invalid status. Must be: open, in_process, partial, or complete"), nil
}
entry, err := lib.EntryByID(db, cfg, requestID)
if err != nil {
return mcp.NewToolResultError("Request not found"), nil
}
if entry.Type != lib.TypeRequest {
return mcp.NewToolResultError("Entry is not a request"), nil
}
// Parse current data, update status
var data lib.RequestData
if entry.DataText != "" {
_ = json.Unmarshal([]byte(entry.DataText), &data)
}
data.Status = status
dataJSON, _ := json.Marshal(data)
entry.DataText = string(dataJSON)
if err := lib.EntryWrite(db, cfg, userID, entry); err != nil {
return mcp.NewToolResultError("Failed to update: " + err.Error()), nil
}
return mcp.NewToolResultText(fmt.Sprintf("Updated request %s status to '%s'", requestID, status)), nil
}
}
func addCommentHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc {
return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID := UserIDFromContext(ctx)
if userID == "" {
return mcp.NewToolResultError("Not authenticated"), nil
}
requestID, err := req.RequireString("request_id")
if err != nil {
return mcp.NewToolResultError("request_id is required"), nil
}
comment, err := req.RequireString("comment")
if err != nil {
return mcp.NewToolResultError("comment is required"), nil
}
side, err := req.RequireString("side")
if err != nil {
return mcp.NewToolResultError("side is required"), nil
}
if side != "buyer" && side != "seller" {
return mcp.NewToolResultError("side must be 'buyer' or 'seller'"), nil
}
entry, err := lib.EntryByID(db, cfg, requestID)
if err != nil {
return mcp.NewToolResultError("Request not found"), nil
}
if entry.Type != lib.TypeRequest {
return mcp.NewToolResultError("Entry is not a request"), nil
}
var data lib.RequestData
if entry.DataText != "" {
_ = json.Unmarshal([]byte(entry.DataText), &data)
}
if side == "buyer" {
data.BuyerComment = comment
} else {
data.SellerComment = comment
}
dataJSON, _ := json.Marshal(data)
entry.DataText = string(dataJSON)
if err := lib.EntryWrite(db, cfg, userID, entry); err != nil {
return mcp.NewToolResultError("Failed to update: " + err.Error()), nil
}
return mcp.NewToolResultText(fmt.Sprintf("Added %s comment to request %s", side, requestID)), nil
}
}
func listAnswersHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc {
return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID := UserIDFromContext(ctx)
if userID == "" {
return mcp.NewToolResultError("Not authenticated"), nil
}
projectID, err := req.RequireString("project_id")
if err != nil {
return mcp.NewToolResultError("project_id is required"), nil
}
if err := lib.CheckAccessRead(db, userID, projectID, ""); err != nil {
return mcp.NewToolResultError("Access denied"), nil
}
filter := lib.EntryFilter{ProjectID: projectID, Type: lib.TypeAnswer}
entries, err := lib.EntryRead(db, cfg, userID, projectID, filter)
if err != nil {
return mcp.NewToolResultError("Failed to list answers: " + err.Error()), nil
}
var lines []string
for _, e := range entries {
lines = append(lines, fmt.Sprintf("- %s (id: %s)", e.SummaryText, e.EntryID))
}
if len(lines) == 0 {
return mcp.NewToolResultText("No answer documents found."), nil
}
return mcp.NewToolResultText("Answers:\n" + strings.Join(lines, "\n")), nil
}
}
func linkAnswerHandler(db *lib.DB) server.ToolHandlerFunc {
return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID := UserIDFromContext(ctx)
if userID == "" {
return mcp.NewToolResultError("Not authenticated"), nil
}
requestID, err := req.RequireString("request_id")
if err != nil {
return mcp.NewToolResultError("request_id is required"), nil
}
answerID, err := req.RequireString("answer_id")
if err != nil {
return mcp.NewToolResultError("answer_id is required"), nil
}
if err := lib.AnswerLinkCreate(db, answerID, requestID, userID); err != nil {
return mcp.NewToolResultError("Failed to link answer: " + err.Error()), nil
}
return mcp.NewToolResultText(fmt.Sprintf("Linked answer %s to request %s", answerID, requestID)), nil
}
}
func listTasksHandler(db *lib.DB, cfg *lib.Config) server.ToolHandlerFunc {
return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID := UserIDFromContext(ctx)
if userID == "" {
return mcp.NewToolResultError("Not authenticated"), nil
}
tasks, err := lib.TasksByUser(db, cfg, userID)
if err != nil {
return mcp.NewToolResultError("Failed to list tasks: " + err.Error()), nil
}
projectID := req.GetString("project_id", "")
var lines []string
for _, t := range tasks {
if projectID != "" && t.ProjectID != projectID {
continue
}
lines = append(lines, fmt.Sprintf("- %s (id: %s, project: %s, stage: %s)",
t.SummaryText, t.EntryID, t.ProjectID, t.Stage))
}
if len(lines) == 0 {
return mcp.NewToolResultText("No tasks assigned to you."), nil
}
return mcp.NewToolResultText("Your tasks:\n" + strings.Join(lines, "\n")), nil
}
}

View File

@ -60,6 +60,31 @@ func AuthMiddleware(db *lib.DB, jwtSecret []byte) func(http.Handler) http.Handle
}
}
// OAuthBearerAuth validates OAuth 2.0 bearer tokens for MCP endpoints.
func OAuthBearerAuth(db *lib.DB) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
auth := r.Header.Get("Authorization")
if !strings.HasPrefix(auth, "Bearer ") {
w.Header().Set("WWW-Authenticate", "Bearer")
ErrorResponse(w, http.StatusUnauthorized, "missing_token", "Bearer token required")
return
}
tokenStr := strings.TrimPrefix(auth, "Bearer ")
token, err := lib.OAuthTokenValidate(db, tokenStr)
if err != nil || token == nil {
w.Header().Set("WWW-Authenticate", "Bearer error=\"invalid_token\"")
ErrorResponse(w, http.StatusUnauthorized, "invalid_token", "Invalid or expired token")
return
}
ctx := context.WithValue(r.Context(), ctxUserID, token.UserID)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// LoggingMiddleware logs HTTP requests.
func LoggingMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

417
api/oauth.go Normal file
View File

@ -0,0 +1,417 @@
package api
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"html/template"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"time"
"github.com/mish/dealspace/lib"
)
// OAuthHandlers holds dependencies for OAuth endpoints.
type OAuthHandlers struct {
DB *lib.DB
Cfg *lib.Config
}
// NewOAuthHandlers creates OAuth handlers.
func NewOAuthHandlers(db *lib.DB, cfg *lib.Config) *OAuthHandlers {
return &OAuthHandlers{DB: db, Cfg: cfg}
}
func baseURL(r *http.Request) string {
scheme := "https"
if r.TLS == nil && !strings.HasPrefix(r.Header.Get("X-Forwarded-Proto"), "https") {
scheme = "http"
}
return scheme + "://" + r.Host
}
// Metadata serves GET /.well-known/oauth-authorization-server (RFC 8414).
func (o *OAuthHandlers) Metadata(w http.ResponseWriter, r *http.Request) {
base := baseURL(r)
JSONResponse(w, http.StatusOK, map[string]any{
"issuer": base,
"authorization_endpoint": base + "/oauth/authorize",
"token_endpoint": base + "/oauth/token",
"revocation_endpoint": base + "/oauth/revoke",
"response_types_supported": []string{"code"},
"grant_types_supported": []string{"authorization_code"},
"code_challenge_methods_supported": []string{"S256"},
"token_endpoint_auth_methods_supported": []string{"none"},
})
}
// ResourceMetadata serves GET /.well-known/oauth-protected-resource (RFC 9728).
func (o *OAuthHandlers) ResourceMetadata(w http.ResponseWriter, r *http.Request) {
base := baseURL(r)
JSONResponse(w, http.StatusOK, map[string]any{
"resource": base,
"authorization_servers": []string{base},
})
}
// Authorize handles GET /oauth/authorize — shows consent page or redirects to login.
func (o *OAuthHandlers) Authorize(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
clientID := q.Get("client_id")
redirectURI := q.Get("redirect_uri")
responseType := q.Get("response_type")
codeChallenge := q.Get("code_challenge")
codeChallengeMethod := q.Get("code_challenge_method")
state := q.Get("state")
_ = q.Get("scope") // scope passed through via query string to POST handler
// Validate required params
if responseType != "code" {
oauthError(w, redirectURI, state, "unsupported_response_type", "Only response_type=code is supported")
return
}
if codeChallenge == "" || codeChallengeMethod != "S256" {
oauthError(w, redirectURI, state, "invalid_request", "PKCE with S256 is required")
return
}
// Validate client
client, err := lib.OAuthClientByID(o.DB, clientID)
if err != nil || client == nil {
oauthError(w, redirectURI, state, "invalid_client", "Unknown client_id")
return
}
// Validate redirect_uri
if !validRedirectURI(client.RedirectURIs, redirectURI) {
ErrorResponse(w, http.StatusBadRequest, "invalid_request", "Invalid redirect_uri")
return
}
// Check if user is authenticated via JWT in cookie or Authorization header
userID := o.extractUserID(r)
if userID == "" {
// Redirect to login with return URL
loginURL := "/app/login?next=" + url.QueryEscape(r.URL.RequestURI())
http.Redirect(w, r, loginURL, http.StatusFound)
return
}
// Show consent page
o.serveConsentPage(w, client.ClientName, r.URL.RequestURI())
}
// AuthorizeApprove handles POST /oauth/authorize — processes consent approval.
func (o *OAuthHandlers) AuthorizeApprove(w http.ResponseWriter, r *http.Request) {
// Parse the original authorize params from the form
if err := r.ParseForm(); err != nil {
ErrorResponse(w, http.StatusBadRequest, "invalid_request", "Invalid form data")
return
}
// The form includes the original query string in a hidden field
originalQuery := r.FormValue("original_query")
parsedQuery, err := url.ParseQuery(originalQuery)
if err != nil {
ErrorResponse(w, http.StatusBadRequest, "invalid_request", "Invalid query parameters")
return
}
clientID := parsedQuery.Get("client_id")
redirectURI := parsedQuery.Get("redirect_uri")
codeChallenge := parsedQuery.Get("code_challenge")
state := parsedQuery.Get("state")
scope := parsedQuery.Get("scope")
// Check denial
if r.FormValue("action") == "deny" {
oauthRedirect(w, r, redirectURI, state, "", "access_denied")
return
}
// Validate client again
client, err := lib.OAuthClientByID(o.DB, clientID)
if err != nil || client == nil {
ErrorResponse(w, http.StatusBadRequest, "invalid_client", "Unknown client_id")
return
}
if !validRedirectURI(client.RedirectURIs, redirectURI) {
ErrorResponse(w, http.StatusBadRequest, "invalid_request", "Invalid redirect_uri")
return
}
// User must be authenticated
userID := o.extractUserID(r)
if userID == "" {
ErrorResponse(w, http.StatusUnauthorized, "login_required", "Authentication required")
return
}
// Generate authorization code
codeBytes := make([]byte, 32)
if _, err := rand.Read(codeBytes); err != nil {
ErrorResponse(w, http.StatusInternalServerError, "internal", "Failed to generate code")
return
}
codeStr := hex.EncodeToString(codeBytes)
oauthCode := &lib.OAuthCode{
Code: codeStr,
ClientID: clientID,
UserID: userID,
RedirectURI: redirectURI,
CodeChallenge: codeChallenge,
Scope: scope,
ExpiresAt: time.Now().Add(10 * time.Minute).UnixMilli(),
Used: false,
}
if err := lib.OAuthCodeCreate(o.DB, oauthCode); err != nil {
ErrorResponse(w, http.StatusInternalServerError, "internal", "Failed to store authorization code")
return
}
oauthRedirect(w, r, redirectURI, state, codeStr, "")
}
// Token handles POST /oauth/token — exchanges authorization code for access token.
func (o *OAuthHandlers) Token(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
tokenError(w, "invalid_request", "Invalid form data")
return
}
grantType := r.FormValue("grant_type")
if grantType != "authorization_code" {
tokenError(w, "unsupported_grant_type", "Only authorization_code is supported")
return
}
codeStr := r.FormValue("code")
redirectURI := r.FormValue("redirect_uri")
clientID := r.FormValue("client_id")
codeVerifier := r.FormValue("code_verifier")
if codeStr == "" || clientID == "" || codeVerifier == "" {
tokenError(w, "invalid_request", "Missing required parameters")
return
}
// Consume the code (marks used, checks expiry)
code, err := lib.OAuthCodeConsume(o.DB, codeStr)
if err != nil {
tokenError(w, "invalid_grant", "Invalid or expired authorization code")
return
}
// Verify client_id matches
if code.ClientID != clientID {
tokenError(w, "invalid_grant", "Client mismatch")
return
}
// Verify redirect_uri matches
if code.RedirectURI != redirectURI {
tokenError(w, "invalid_grant", "Redirect URI mismatch")
return
}
// Verify PKCE: SHA256(code_verifier) must match code_challenge
verifierHash := sha256.Sum256([]byte(codeVerifier))
computedChallenge := base64.RawURLEncoding.EncodeToString(verifierHash[:])
if computedChallenge != code.CodeChallenge {
tokenError(w, "invalid_grant", "PKCE verification failed")
return
}
// Generate access token
tokenBytes := make([]byte, 64)
if _, err := rand.Read(tokenBytes); err != nil {
tokenError(w, "server_error", "Failed to generate token")
return
}
tokenStr := hex.EncodeToString(tokenBytes)
now := time.Now()
expiresIn := int64(24 * 60 * 60) // 24 hours in seconds
oauthToken := &lib.OAuthToken{
Token: tokenStr,
ClientID: clientID,
UserID: code.UserID,
Scope: code.Scope,
ExpiresAt: now.Add(24 * time.Hour).UnixMilli(),
Revoked: false,
CreatedAt: now.UnixMilli(),
}
if err := lib.OAuthTokenCreate(o.DB, oauthToken); err != nil {
tokenError(w, "server_error", "Failed to store token")
return
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Cache-Control", "no-store")
json.NewEncoder(w).Encode(map[string]any{
"access_token": tokenStr,
"token_type": "Bearer",
"expires_in": expiresIn,
})
}
// Revoke handles POST /oauth/revoke — revokes an access token (RFC 7009).
func (o *OAuthHandlers) Revoke(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
w.WriteHeader(http.StatusOK)
return
}
tokenStr := r.FormValue("token")
if tokenStr != "" {
_ = lib.OAuthTokenRevoke(o.DB, tokenStr)
}
// Always return 200 per RFC 7009
w.WriteHeader(http.StatusOK)
}
// extractUserID checks JWT from Authorization header or ds_token cookie.
func (o *OAuthHandlers) extractUserID(r *http.Request) string {
// Try Authorization header first
auth := r.Header.Get("Authorization")
if strings.HasPrefix(auth, "Bearer ") {
token := strings.TrimPrefix(auth, "Bearer ")
claims, err := validateJWT(token, o.Cfg.JWTSecret)
if err == nil {
session, err := lib.SessionByID(o.DB, claims.SessionID)
if err == nil && session != nil && !session.Revoked && session.ExpiresAt >= time.Now().UnixMilli() {
return claims.UserID
}
}
}
// Try cookie
cookie, err := r.Cookie("ds_token")
if err == nil && cookie.Value != "" {
claims, err := validateJWT(cookie.Value, o.Cfg.JWTSecret)
if err == nil {
session, err := lib.SessionByID(o.DB, claims.SessionID)
if err == nil && session != nil && !session.Revoked && session.ExpiresAt >= time.Now().UnixMilli() {
return claims.UserID
}
}
}
return ""
}
func (o *OAuthHandlers) serveConsentPage(w http.ResponseWriter, appName, authorizeURI string) {
// Extract query string from the authorize URI
parsed, _ := url.Parse(authorizeURI)
queryString := parsed.RawQuery
candidates := []string{
"portal/templates/auth/consent.html",
filepath.Join("/opt/dealspace/portal/templates/auth/consent.html"),
}
var tmpl *template.Template
var err error
for _, p := range candidates {
if _, statErr := os.Stat(p); statErr == nil {
tmpl, err = template.ParseFiles(p)
if err == nil {
break
}
}
}
if tmpl == nil {
http.Error(w, "Consent template not found", http.StatusInternalServerError)
return
}
data := map[string]string{
"AppName": appName,
"OriginalQuery": queryString,
}
w.Header().Set("Content-Type", "text/html; charset=utf-8")
tmpl.Execute(w, data)
}
func validRedirectURI(registered []string, uri string) bool {
parsed, err := url.Parse(uri)
if err != nil {
return false
}
for _, r := range registered {
rParsed, err := url.Parse(r)
if err != nil {
continue
}
// Match scheme + host (port-agnostic for localhost)
if parsed.Scheme == rParsed.Scheme && parsed.Hostname() == rParsed.Hostname() {
return true
}
}
return false
}
func oauthError(w http.ResponseWriter, redirectURI, state, errCode, errDesc string) {
if redirectURI == "" {
ErrorResponse(w, http.StatusBadRequest, errCode, errDesc)
return
}
oauthRedirect(w, nil, redirectURI, state, "", errCode)
}
func oauthRedirect(w http.ResponseWriter, r *http.Request, redirectURI, state, code, errCode string) {
u, err := url.Parse(redirectURI)
if err != nil {
ErrorResponse(w, http.StatusBadRequest, "invalid_request", "Invalid redirect_uri")
return
}
q := u.Query()
if code != "" {
q.Set("code", code)
}
if errCode != "" {
q.Set("error", errCode)
}
if state != "" {
q.Set("state", state)
}
u.RawQuery = q.Encode()
if r != nil {
http.Redirect(w, r, u.String(), http.StatusFound)
} else {
w.Header().Set("Location", u.String())
w.WriteHeader(http.StatusFound)
}
}
func tokenError(w http.ResponseWriter, errCode, errDesc string) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
json.NewEncoder(w).Encode(map[string]string{
"error": errCode,
"error_description": errDesc,
})
}
// SeedOAuthClient creates the default Claude OAuth client if it doesn't exist.
func SeedOAuthClient(db *lib.DB) {
client := &lib.OAuthClient{
ClientID: "claude",
ClientName: "Claude",
RedirectURIs: []string{
"http://localhost",
"http://127.0.0.1",
"http://localhost:0",
"http://127.0.0.1:0",
},
CreatedAt: time.Now().UnixMilli(),
}
_ = lib.OAuthClientCreate(db, client)
}

View File

@ -1,17 +1,20 @@
package api
import (
"context"
"io/fs"
"net/http"
"github.com/go-chi/chi/v5"
"github.com/mark3labs/mcp-go/server"
"github.com/mish/dealspace/lib"
)
// NewRouter creates the main router with all routes registered.
func NewRouter(db *lib.DB, cfg *lib.Config, store lib.ObjectStore, websiteFS fs.FS, portalFS fs.FS) *chi.Mux {
func NewRouter(db *lib.DB, cfg *lib.Config, store lib.ObjectStore, websiteFS fs.FS, portalFS fs.FS, mcpServer *server.MCPServer) *chi.Mux {
r := chi.NewRouter()
h := NewHandlers(db, cfg, store)
oauth := NewOAuthHandlers(db, cfg)
// Global middleware
r.Use(LoggingMiddleware)
@ -103,6 +106,33 @@ func NewRouter(db *lib.DB, cfg *lib.Config, store lib.ObjectStore, websiteFS fs.
r.Post("/admin/impersonate", h.AdminImpersonate)
})
// OAuth metadata (unauthenticated, auto-discovery)
r.Get("/.well-known/oauth-authorization-server", oauth.Metadata)
r.Get("/.well-known/oauth-protected-resource", oauth.ResourceMetadata)
// OAuth endpoints
r.Get("/oauth/authorize", oauth.Authorize)
r.Post("/oauth/authorize", oauth.AuthorizeApprove)
r.Post("/oauth/token", oauth.Token)
r.Post("/oauth/revoke", oauth.Revoke)
// MCP endpoint (OAuth bearer auth)
if mcpServer != nil {
mcpHTTP := server.NewStreamableHTTPServer(mcpServer,
server.WithEndpointPath("/mcp"),
server.WithHTTPContextFunc(func(ctx context.Context, r *http.Request) context.Context {
if uid := UserIDFromContext(r.Context()); uid != "" {
return context.WithValue(ctx, ctxUserID, uid)
}
return ctx
}),
)
r.Group(func(r chi.Router) {
r.Use(OAuthBearerAuth(db))
r.Handle("/mcp", mcpHTTP)
})
}
// Portal app routes (serve templates, auth checked client-side via JS)
r.Get("/app", func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/app/projects", http.StatusFound)

View File

@ -40,6 +40,9 @@ func main() {
// Always seed super admin accounts on startup
seedSuperAdmins(db)
// Seed OAuth client for Claude
api.SeedOAuthClient(db)
// Seed demo data if SEED_DEMO=true
if os.Getenv("SEED_DEMO") == "true" {
seedDemoData(db, cfg)
@ -51,7 +54,10 @@ func main() {
log.Fatalf("embed website: %v", err)
}
router := api.NewRouter(db, cfg, store, websiteFS, nil)
// MCP server for Claude integration
mcpServer := api.NewMCPServer(db, cfg)
router := api.NewRouter(db, cfg, store, websiteFS, nil, mcpServer)
addr := ":" + cfg.Port
log.Printf("dealspace starting on %s (env=%s)", addr, cfg.Env)

11
go.mod
View File

@ -6,26 +6,35 @@ require (
github.com/go-chi/chi/v5 v5.2.1
github.com/google/uuid v1.6.0
github.com/klauspost/compress v1.18.0
github.com/mark3labs/mcp-go v0.44.1
github.com/mattn/go-sqlite3 v1.14.24
github.com/pdfcpu/pdfcpu v0.11.1
github.com/xuri/excelize/v2 v2.10.1
golang.org/x/crypto v0.48.0
)
require (
github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/buger/jsonparser v1.1.1 // indirect
github.com/clipperhouse/uax29/v2 v2.2.0 // indirect
github.com/hhrutter/lzw v1.0.0 // indirect
github.com/hhrutter/pkcs7 v0.2.0 // indirect
github.com/hhrutter/tiff v1.0.2 // indirect
github.com/invopop/jsonschema v0.13.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/mattn/go-runewidth v0.0.19 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/richardlehane/mscfb v1.0.6 // indirect
github.com/richardlehane/msoleps v1.0.6 // indirect
github.com/spf13/cast v1.7.1 // indirect
github.com/tiendc/go-deepcopy v1.7.2 // indirect
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
github.com/xuri/efp v0.0.1 // indirect
github.com/xuri/excelize/v2 v2.10.1 // indirect
github.com/xuri/nfp v0.0.2-0.20250530014748-2ddeb826f9a9 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
golang.org/x/image v0.32.0 // indirect
golang.org/x/net v0.50.0 // indirect
golang.org/x/text v0.34.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

39
go.sum
View File

@ -1,7 +1,17 @@
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs=
github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0=
github.com/clipperhouse/uax29/v2 v2.2.0 h1:ChwIKnQN3kcZteTXMgb1wztSgaU+ZemkgWdohwgs8tY=
github.com/clipperhouse/uax29/v2 v2.2.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8=
github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
github.com/go-chi/chi/v5 v5.2.1 h1:KOIHODQj58PmL80G2Eak4WdvUzjSJSm0vG72crDCqb8=
github.com/go-chi/chi/v5 v5.2.1/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hhrutter/lzw v1.0.0 h1:laL89Llp86W3rRs83LvKbwYRx6INE8gDn0XNb1oXtm0=
@ -10,8 +20,19 @@ github.com/hhrutter/pkcs7 v0.2.0 h1:i4HN2XMbGQpZRnKBLsUwO3dSckzgX142TNqY/KfXg+I=
github.com/hhrutter/pkcs7 v0.2.0/go.mod h1:aEzKz0+ZAlz7YaEMY47jDHL14hVWD6iXt0AgqgAvWgE=
github.com/hhrutter/tiff v1.0.2 h1:7H3FQQpKu/i5WaSChoD1nnJbGx4MxU5TlNqqpxw55z8=
github.com/hhrutter/tiff v1.0.2/go.mod h1:pcOeuK5loFUE7Y/WnzGw20YxUdnqjY1P0Jlcieb/cCw=
github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E=
github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/mark3labs/mcp-go v0.44.1 h1:2PKppYlT9X2fXnE8SNYQLAX4hNjfPB0oNLqQVcN6mE8=
github.com/mark3labs/mcp-go v0.44.1/go.mod h1:YnJfOL382MIWDx1kMY+2zsRHU/q78dBg9aFb8W6Thdw=
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
@ -20,31 +41,41 @@ github.com/pdfcpu/pdfcpu v0.11.1 h1:htHBSkGH5jMKWC6e0sihBFbcKZ8vG1M67c8/dJxhjas=
github.com/pdfcpu/pdfcpu v0.11.1/go.mod h1:pP3aGga7pRvwFWAm9WwFvo+V68DfANi9kxSQYioNYcw=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/richardlehane/mscfb v1.0.6 h1:eN3bvvZCp00bs7Zf52bxNwAx5lJDBK1tCuH19qq5aC8=
github.com/richardlehane/mscfb v1.0.6/go.mod h1:pe0+IUIc0AHh0+teNzBlJCtSyZdFOGgV4ZK9bsoV+Jo=
github.com/richardlehane/msoleps v1.0.6 h1:9BvkpjvD+iUBalUY4esMwv6uBkfOip/Lzvd93jvR9gg=
github.com/richardlehane/msoleps v1.0.6/go.mod h1:BWev5JBpU9Ko2WAgmZEuiz4/u3ZYTKbjLycmwiWUfWg=
github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8=
github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs=
github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y=
github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/tiendc/go-deepcopy v1.7.2 h1:Ut2yYR7W9tWjTQitganoIue4UGxZwCcJy3orjrrIj44=
github.com/tiendc/go-deepcopy v1.7.2/go.mod h1:4bKjNC2r7boYOkD2IOuZpYjmlDdzjbpTRyCx+goBCJQ=
github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc=
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
github.com/xuri/efp v0.0.1 h1:fws5Rv3myXyYni8uwj2qKjVaRP30PdjeYe2Y6FDsCL8=
github.com/xuri/efp v0.0.1/go.mod h1:ybY/Jr0T0GTCnYjKqmdwxyxn2BQf2RcQIIvex5QldPI=
github.com/xuri/excelize/v2 v2.10.1 h1:V62UlqopMqha3kOpnlHy2CcRVw1V8E63jFoWUmMzxN0=
github.com/xuri/excelize/v2 v2.10.1/go.mod h1:iG5tARpgaEeIhTqt3/fgXCGoBRt4hNXgCp3tfXKoOIc=
github.com/xuri/nfp v0.0.2-0.20250530014748-2ddeb826f9a9 h1:+C0TIdyyYmzadGaL/HBLbf3WdLgC29pgyhTjAT/0nuE=
github.com/xuri/nfp v0.0.2-0.20250530014748-2ddeb826f9a9/go.mod h1:WwHg+CVyzlv/TX9xqBFXEZAuxOPxn2k1GNHwG41IIUQ=
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
golang.org/x/image v0.32.0 h1:6lZQWq75h7L5IWNk0r+SCpUJ6tUVd3v4ZHnbRKLkUDQ=
golang.org/x/image v0.32.0/go.mod h1:/R37rrQmKXtO6tYXAjtDLwQgFLHmhW+V6ayXlxzP2Pc=
golang.org/x/net v0.50.0 h1:ucWh9eiCGyDR3vtzso0WMQinm2Dnt8cFMuQa9K33J60=
golang.org/x/net v0.50.0/go.mod h1:UgoSli3F/pBgdJBHCTc+tp3gmrU4XswgGRgtnwWTfyM=
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk=
golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -3,6 +3,7 @@ package lib
import (
"crypto/subtle"
"database/sql"
"encoding/json"
"errors"
"fmt"
"os"
@ -1003,3 +1004,99 @@ func SoftDeleteTree(db *DB, entryID, actorID string) error {
}
return nil
}
// --- OAuth ---
// OAuthClientByID looks up an OAuth client by its client_id.
func OAuthClientByID(db *DB, clientID string) (*OAuthClient, error) {
row := db.Conn.QueryRow(`SELECT client_id, client_name, redirect_uris, created_at FROM oauth_clients WHERE client_id = ?`, clientID)
var c OAuthClient
var uris string
if err := row.Scan(&c.ClientID, &c.ClientName, &uris, &c.CreatedAt); err != nil {
return nil, err
}
_ = json.Unmarshal([]byte(uris), &c.RedirectURIs)
return &c, nil
}
// OAuthClientCreate inserts a new OAuth client (skips if already exists).
func OAuthClientCreate(db *DB, c *OAuthClient) error {
uris, _ := json.Marshal(c.RedirectURIs)
_, err := db.Conn.Exec(
`INSERT OR IGNORE INTO oauth_clients (client_id, client_name, redirect_uris, created_at) VALUES (?, ?, ?, ?)`,
c.ClientID, c.ClientName, string(uris), c.CreatedAt,
)
return err
}
// OAuthCodeCreate stores an authorization code.
func OAuthCodeCreate(db *DB, code *OAuthCode) error {
used := 0
if code.Used {
used = 1
}
_, err := db.Conn.Exec(
`INSERT INTO oauth_codes (code, client_id, user_id, redirect_uri, code_challenge, scope, expires_at, used) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
code.Code, code.ClientID, code.UserID, code.RedirectURI, code.CodeChallenge, code.Scope, code.ExpiresAt, used,
)
return err
}
// OAuthCodeConsume looks up and marks an authorization code as used. Returns nil if invalid/expired/used.
func OAuthCodeConsume(db *DB, codeStr string) (*OAuthCode, error) {
row := db.Conn.QueryRow(
`SELECT code, client_id, user_id, redirect_uri, code_challenge, scope, expires_at, used FROM oauth_codes WHERE code = ?`,
codeStr,
)
var c OAuthCode
var used int
if err := row.Scan(&c.Code, &c.ClientID, &c.UserID, &c.RedirectURI, &c.CodeChallenge, &c.Scope, &c.ExpiresAt, &used); err != nil {
return nil, err
}
if used == 1 || c.ExpiresAt < time.Now().UnixMilli() {
return nil, ErrAccessDenied
}
// Mark as used
_, err := db.Conn.Exec(`UPDATE oauth_codes SET used = 1 WHERE code = ?`, codeStr)
if err != nil {
return nil, err
}
return &c, nil
}
// OAuthTokenCreate stores an access token.
func OAuthTokenCreate(db *DB, t *OAuthToken) error {
revoked := 0
if t.Revoked {
revoked = 1
}
_, err := db.Conn.Exec(
`INSERT INTO oauth_tokens (token, client_id, user_id, scope, expires_at, revoked, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)`,
t.Token, t.ClientID, t.UserID, t.Scope, t.ExpiresAt, revoked, t.CreatedAt,
)
return err
}
// OAuthTokenValidate returns a token if it's valid (not expired, not revoked).
func OAuthTokenValidate(db *DB, tokenStr string) (*OAuthToken, error) {
row := db.Conn.QueryRow(
`SELECT token, client_id, user_id, scope, expires_at, revoked, created_at FROM oauth_tokens WHERE token = ?`,
tokenStr,
)
var t OAuthToken
var revoked int
if err := row.Scan(&t.Token, &t.ClientID, &t.UserID, &t.Scope, &t.ExpiresAt, &revoked, &t.CreatedAt); err != nil {
return nil, err
}
t.Revoked = revoked == 1
if t.Revoked || t.ExpiresAt < time.Now().UnixMilli() {
return nil, ErrAccessDenied
}
return &t, nil
}
// OAuthTokenRevoke marks a token as revoked.
func OAuthTokenRevoke(db *DB, tokenStr string) error {
_, err := db.Conn.Exec(`UPDATE oauth_tokens SET revoked = 1 WHERE token = ?`, tokenStr)
return err
}

View File

@ -270,3 +270,34 @@ type WorkstreamData struct {
Name string `json:"name"`
Description string `json:"description,omitempty"`
}
// OAuthClient represents a registered OAuth 2.0 client application.
type OAuthClient struct {
ClientID string `json:"client_id"`
ClientName string `json:"client_name"`
RedirectURIs []string `json:"redirect_uris"`
CreatedAt int64 `json:"created_at"`
}
// OAuthCode represents an OAuth 2.0 authorization code.
type OAuthCode struct {
Code string `json:"code"`
ClientID string `json:"client_id"`
UserID string `json:"user_id"`
RedirectURI string `json:"redirect_uri"`
CodeChallenge string `json:"code_challenge"`
Scope string `json:"scope"`
ExpiresAt int64 `json:"expires_at"`
Used bool `json:"used"`
}
// OAuthToken represents an OAuth 2.0 access token.
type OAuthToken struct {
Token string `json:"token"`
ClientID string `json:"client_id"`
UserID string `json:"user_id"`
Scope string `json:"scope"`
ExpiresAt int64 `json:"expires_at"`
Revoked bool `json:"revoked"`
CreatedAt int64 `json:"created_at"`
}

27
migrations/003_oauth.sql Normal file
View File

@ -0,0 +1,27 @@
CREATE TABLE IF NOT EXISTS oauth_clients (
client_id TEXT PRIMARY KEY,
client_name TEXT NOT NULL,
redirect_uris TEXT NOT NULL,
created_at INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS oauth_codes (
code TEXT PRIMARY KEY,
client_id TEXT NOT NULL,
user_id TEXT NOT NULL,
redirect_uri TEXT NOT NULL,
code_challenge TEXT NOT NULL,
scope TEXT NOT NULL DEFAULT '',
expires_at INTEGER NOT NULL,
used INTEGER NOT NULL DEFAULT 0
);
CREATE TABLE IF NOT EXISTS oauth_tokens (
token TEXT PRIMARY KEY,
client_id TEXT NOT NULL,
user_id TEXT NOT NULL,
scope TEXT NOT NULL DEFAULT '',
expires_at INTEGER NOT NULL,
revoked INTEGER NOT NULL DEFAULT 0,
created_at INTEGER NOT NULL
);

View File

@ -0,0 +1,58 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Authorize Application — Dealspace</title>
<script src="https://cdn.tailwindcss.com"></script>
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap" rel="stylesheet">
<style>body { font-family: 'Inter', sans-serif; }</style>
</head>
<body class="bg-gray-50 min-h-screen flex items-center justify-center">
<div class="bg-white rounded-xl shadow-lg p-8 w-full max-w-md">
<div class="text-center mb-6">
<div class="w-16 h-16 bg-indigo-100 rounded-full flex items-center justify-center mx-auto mb-4">
<svg class="w-8 h-8 text-indigo-600" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M9 12l2 2 4-4m5.618-4.016A11.955 11.955 0 0112 2.944a11.955 11.955 0 01-8.618 3.04A12.02 12.02 0 003 9c0 5.591 3.824 10.29 9 11.622 5.176-1.332 9-6.03 9-11.622 0-1.042-.133-2.052-.382-3.016z"/>
</svg>
</div>
<h1 class="text-2xl font-bold text-gray-900">Authorize {{.AppName}}</h1>
<p class="text-gray-500 mt-2">This application is requesting access to your Dealspace account.</p>
</div>
<div class="bg-gray-50 rounded-lg p-4 mb-6">
<h3 class="text-sm font-semibold text-gray-700 mb-2">This will allow {{.AppName}} to:</h3>
<ul class="space-y-2 text-sm text-gray-600">
<li class="flex items-center gap-2">
<svg class="w-4 h-4 text-green-500 flex-shrink-0" fill="currentColor" viewBox="0 0 20 20"><path fill-rule="evenodd" d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z" clip-rule="evenodd"/></svg>
View your projects and requests
</li>
<li class="flex items-center gap-2">
<svg class="w-4 h-4 text-green-500 flex-shrink-0" fill="currentColor" viewBox="0 0 20 20"><path fill-rule="evenodd" d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z" clip-rule="evenodd"/></svg>
Update request statuses and comments
</li>
<li class="flex items-center gap-2">
<svg class="w-4 h-4 text-green-500 flex-shrink-0" fill="currentColor" viewBox="0 0 20 20"><path fill-rule="evenodd" d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z" clip-rule="evenodd"/></svg>
Link answers to requests
</li>
<li class="flex items-center gap-2">
<svg class="w-4 h-4 text-green-500 flex-shrink-0" fill="currentColor" viewBox="0 0 20 20"><path fill-rule="evenodd" d="M16.707 5.293a1 1 0 010 1.414l-8 8a1 1 0 01-1.414 0l-4-4a1 1 0 011.414-1.414L8 12.586l7.293-7.293a1 1 0 011.414 0z" clip-rule="evenodd"/></svg>
View your tasks
</li>
</ul>
</div>
<form method="POST" action="/oauth/authorize" class="space-y-3">
<input type="hidden" name="original_query" value="{{.OriginalQuery}}">
<button type="submit" name="action" value="allow" class="w-full bg-indigo-600 text-white py-2.5 px-4 rounded-lg font-medium hover:bg-indigo-700 transition-colors">
Allow Access
</button>
<button type="submit" name="action" value="deny" class="w-full bg-white text-gray-700 py-2.5 px-4 rounded-lg font-medium border border-gray-300 hover:bg-gray-50 transition-colors">
Deny
</button>
</form>
<p class="text-xs text-gray-400 text-center mt-4">You can revoke access at any time from your account settings.</p>
</div>
</body>
</html>

View File

@ -85,9 +85,15 @@
</div>
<script>
// Parse ?next= redirect URL
const params = new URLSearchParams(window.location.search);
const nextURL = params.get('next') || '/app/tasks';
// If already logged in, redirect
if (localStorage.getItem('ds_token')) {
window.location.href = '/app/tasks';
// Set cookie too (for server-side auth like OAuth consent)
document.cookie = 'ds_token=' + localStorage.getItem('ds_token') + '; path=/; SameSite=Lax; max-age=3600';
window.location.href = nextURL;
}
let currentEmail = '';
@ -152,8 +158,11 @@
localStorage.setItem('ds_token', data.token);
localStorage.setItem('ds_user', JSON.stringify(data.user));
// Everyone lands on /app/tasks — admin panel accessible from nav
window.location.href = '/app/tasks';
// Set cookie for server-side auth (OAuth consent flow)
document.cookie = 'ds_token=' + data.token + '; path=/; SameSite=Lax; max-age=3600';
// Redirect to next URL or default to tasks
window.location.href = nextURL;
} catch (err) {
errorEl.textContent = err.message;
errorEl.classList.remove('hidden');