435 lines
12 KiB
Go
435 lines
12 KiB
Go
package api
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"html/template"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/mish/dealspace/lib"
|
|
)
|
|
|
|
// OAuthHandlers holds dependencies for OAuth endpoints.
|
|
type OAuthHandlers struct {
|
|
DB *lib.DB
|
|
Cfg *lib.Config
|
|
}
|
|
|
|
// NewOAuthHandlers creates OAuth handlers.
|
|
func NewOAuthHandlers(db *lib.DB, cfg *lib.Config) *OAuthHandlers {
|
|
return &OAuthHandlers{DB: db, Cfg: cfg}
|
|
}
|
|
|
|
// baseURL returns the externally visible origin ("scheme://host") of r.
// The scheme is https when the connection itself is TLS or when a proxy
// reports an X-Forwarded-Proto value starting with "https"; otherwise http.
// The host is taken verbatim from r.Host.
func baseURL(r *http.Request) string {
	secure := r.TLS != nil || strings.HasPrefix(r.Header.Get("X-Forwarded-Proto"), "https")
	if secure {
		return "https://" + r.Host
	}
	return "http://" + r.Host
}
|
|
|
|
// Metadata serves GET /.well-known/oauth-authorization-server (RFC 8414).
// Every endpoint URL is derived from the request's own origin (see baseURL).
// The document advertises only the authorization-code grant, S256 PKCE and
// public clients (no client authentication at the token endpoint).
func (o *OAuthHandlers) Metadata(w http.ResponseWriter, r *http.Request) {
	issuer := baseURL(r)
	endpoint := func(path string) string { return issuer + path }

	doc := map[string]any{
		"issuer":                                issuer,
		"authorization_endpoint":                endpoint("/oauth/authorize"),
		"token_endpoint":                        endpoint("/oauth/token"),
		"revocation_endpoint":                   endpoint("/oauth/revoke"),
		"response_types_supported":              []string{"code"},
		"grant_types_supported":                 []string{"authorization_code"},
		"code_challenge_methods_supported":      []string{"S256"},
		"token_endpoint_auth_methods_supported": []string{"none"},
	}
	JSONResponse(w, http.StatusOK, doc)
}
|
|
|
|
// ResourceMetadata serves GET /.well-known/oauth-protected-resource (RFC 9728).
// The protected resource and its authorization server share one origin, so
// both fields are the request's base URL.
func (o *OAuthHandlers) ResourceMetadata(w http.ResponseWriter, r *http.Request) {
	self := baseURL(r)

	doc := make(map[string]any, 2)
	doc["resource"] = self
	doc["authorization_servers"] = []string{self}

	JSONResponse(w, http.StatusOK, doc)
}
|
|
|
|
// Authorize handles GET /oauth/authorize — shows consent page or redirects to login.
//
// Validation order is security-relevant. client_id and redirect_uri are
// checked first, and a failure there is answered directly with a 400. Only
// after redirect_uri is known to be registered for the client are protocol
// errors (unsupported response_type, missing S256 PKCE) reported by
// redirecting back to it (RFC 6749 §4.1.2.1). Redirecting before that check
// would let anyone use this endpoint as an open redirector.
//
// An unauthenticated user is sent to /app/login with this request URI as
// "next". An authenticated user is shown the consent page, which posts the
// original query string to AuthorizeApprove.
func (o *OAuthHandlers) Authorize(w http.ResponseWriter, r *http.Request) {
	q := r.URL.Query()
	clientID := q.Get("client_id")
	redirectURI := q.Get("redirect_uri")
	state := q.Get("state")
	// scope is not read here: it reaches AuthorizeApprove inside the raw
	// query string embedded in the consent form.

	// 1. Establish that redirect_uri can be trusted before redirecting anywhere.
	client, err := lib.OAuthClientByID(o.DB, clientID)
	if err != nil || client == nil {
		ErrorResponse(w, http.StatusBadRequest, "invalid_client", "Unknown client_id")
		return
	}
	if !validRedirectURI(client.RedirectURIs, redirectURI) {
		ErrorResponse(w, http.StatusBadRequest, "invalid_request", "Invalid redirect_uri")
		return
	}

	// 2. redirect_uri is now trusted, so protocol errors go back to the client.
	if q.Get("response_type") != "code" {
		oauthError(w, redirectURI, state, "unsupported_response_type", "Only response_type=code is supported")
		return
	}
	if q.Get("code_challenge") == "" || q.Get("code_challenge_method") != "S256" {
		oauthError(w, redirectURI, state, "invalid_request", "PKCE with S256 is required")
		return
	}

	// 3. The user must be logged in (JWT in Authorization header or ds_token cookie).
	if o.extractUserID(r) == "" {
		loginURL := "/app/login?next=" + url.QueryEscape(r.URL.RequestURI())
		http.Redirect(w, r, loginURL, http.StatusFound)
		return
	}

	// 4. Ask for consent.
	o.serveConsentPage(w, client.ClientName, r.URL.RequestURI())
}
|
|
|
|
// AuthorizeApprove handles POST /oauth/authorize — processes consent approval.
//
// The authorization request is re-read from the "original_query" form field,
// which holds the raw query string of the GET request. That field is
// client-controlled, so everything is validated again here:
//   - client_id and redirect_uri are checked before ANY redirect, including
//     the access_denied redirect for action=deny. Otherwise this handler
//     would be an open redirector.
//   - response_type must be "code" and an S256 code_challenge must be present.
//     A code stored without a challenge could never pass Token's PKCE check.
//
// On approval by an authenticated user, a random code (expiring in 10
// minutes) is stored, and the user agent is redirected to redirect_uri with
// code and state.
//
// NOTE(review): no CSRF token is checked. Protection against cross-site
// approval POSTs presumably relies on the ds_token cookie's SameSite
// attribute — confirm.
func (o *OAuthHandlers) AuthorizeApprove(w http.ResponseWriter, r *http.Request) {
	if err := r.ParseForm(); err != nil {
		ErrorResponse(w, http.StatusBadRequest, "invalid_request", "Invalid form data")
		return
	}

	parsedQuery, err := url.ParseQuery(r.FormValue("original_query"))
	if err != nil {
		ErrorResponse(w, http.StatusBadRequest, "invalid_request", "Invalid query parameters")
		return
	}

	clientID := parsedQuery.Get("client_id")
	redirectURI := parsedQuery.Get("redirect_uri")
	codeChallenge := parsedQuery.Get("code_challenge")
	state := parsedQuery.Get("state")
	scope := parsedQuery.Get("scope")

	// Validate client and redirect target before any redirect is issued.
	client, err := lib.OAuthClientByID(o.DB, clientID)
	if err != nil || client == nil {
		ErrorResponse(w, http.StatusBadRequest, "invalid_client", "Unknown client_id")
		return
	}
	if !validRedirectURI(client.RedirectURIs, redirectURI) {
		ErrorResponse(w, http.StatusBadRequest, "invalid_request", "Invalid redirect_uri")
		return
	}

	// The user declined: report access_denied to the (now validated) client.
	if r.FormValue("action") == "deny" {
		oauthRedirect(w, r, redirectURI, state, "", "access_denied")
		return
	}

	// Re-apply the GET handler's protocol requirements to the posted request.
	if parsedQuery.Get("response_type") != "code" {
		oauthRedirect(w, r, redirectURI, state, "", "unsupported_response_type")
		return
	}
	if codeChallenge == "" || parsedQuery.Get("code_challenge_method") != "S256" {
		oauthRedirect(w, r, redirectURI, state, "", "invalid_request")
		return
	}

	// User must be authenticated.
	userID := o.extractUserID(r)
	if userID == "" {
		ErrorResponse(w, http.StatusUnauthorized, "login_required", "Authentication required")
		return
	}

	// 32 random bytes, hex-encoded, as the authorization code.
	codeBytes := make([]byte, 32)
	if _, err := rand.Read(codeBytes); err != nil {
		ErrorResponse(w, http.StatusInternalServerError, "internal", "Failed to generate code")
		return
	}
	codeStr := hex.EncodeToString(codeBytes)

	oauthCode := &lib.OAuthCode{
		Code:          codeStr,
		ClientID:      clientID,
		UserID:        userID,
		RedirectURI:   redirectURI,
		CodeChallenge: codeChallenge,
		Scope:         scope,
		ExpiresAt:     time.Now().Add(10 * time.Minute).UnixMilli(),
		Used:          false,
	}
	if err := lib.OAuthCodeCreate(o.DB, oauthCode); err != nil {
		ErrorResponse(w, http.StatusInternalServerError, "internal", "Failed to store authorization code")
		return
	}

	oauthRedirect(w, r, redirectURI, state, codeStr, "")
}
|
|
|
|
// Token handles POST /oauth/token — exchanges authorization code for access token.
|
|
func (o *OAuthHandlers) Token(w http.ResponseWriter, r *http.Request) {
|
|
if err := r.ParseForm(); err != nil {
|
|
tokenError(w, "invalid_request", "Invalid form data")
|
|
return
|
|
}
|
|
|
|
grantType := r.FormValue("grant_type")
|
|
if grantType != "authorization_code" {
|
|
tokenError(w, "unsupported_grant_type", "Only authorization_code is supported")
|
|
return
|
|
}
|
|
|
|
codeStr := r.FormValue("code")
|
|
redirectURI := r.FormValue("redirect_uri")
|
|
clientID := r.FormValue("client_id")
|
|
codeVerifier := r.FormValue("code_verifier")
|
|
|
|
if codeStr == "" || clientID == "" || codeVerifier == "" {
|
|
tokenError(w, "invalid_request", "Missing required parameters")
|
|
return
|
|
}
|
|
|
|
// Consume the code (marks used, checks expiry)
|
|
code, err := lib.OAuthCodeConsume(o.DB, codeStr)
|
|
if err != nil {
|
|
tokenError(w, "invalid_grant", "Invalid or expired authorization code")
|
|
return
|
|
}
|
|
|
|
// Verify client_id matches
|
|
if code.ClientID != clientID {
|
|
tokenError(w, "invalid_grant", "Client mismatch")
|
|
return
|
|
}
|
|
|
|
// Verify redirect_uri matches
|
|
if code.RedirectURI != redirectURI {
|
|
tokenError(w, "invalid_grant", "Redirect URI mismatch")
|
|
return
|
|
}
|
|
|
|
// Verify PKCE: SHA256(code_verifier) must match code_challenge
|
|
verifierHash := sha256.Sum256([]byte(codeVerifier))
|
|
computedChallenge := base64.RawURLEncoding.EncodeToString(verifierHash[:])
|
|
if computedChallenge != code.CodeChallenge {
|
|
tokenError(w, "invalid_grant", "PKCE verification failed")
|
|
return
|
|
}
|
|
|
|
// Generate access token
|
|
tokenBytes := make([]byte, 64)
|
|
if _, err := rand.Read(tokenBytes); err != nil {
|
|
tokenError(w, "server_error", "Failed to generate token")
|
|
return
|
|
}
|
|
tokenStr := hex.EncodeToString(tokenBytes)
|
|
now := time.Now()
|
|
expiresIn := int64(24 * 60 * 60) // 24 hours in seconds
|
|
|
|
oauthToken := &lib.OAuthToken{
|
|
Token: tokenStr,
|
|
ClientID: clientID,
|
|
UserID: code.UserID,
|
|
Scope: code.Scope,
|
|
ExpiresAt: now.Add(24 * time.Hour).UnixMilli(),
|
|
Revoked: false,
|
|
CreatedAt: now.UnixMilli(),
|
|
}
|
|
if err := lib.OAuthTokenCreate(o.DB, oauthToken); err != nil {
|
|
tokenError(w, "server_error", "Failed to store token")
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.Header().Set("Cache-Control", "no-store")
|
|
json.NewEncoder(w).Encode(map[string]any{
|
|
"access_token": tokenStr,
|
|
"token_type": "Bearer",
|
|
"expires_in": expiresIn,
|
|
})
|
|
}
|
|
|
|
// Revoke handles POST /oauth/revoke — revokes an access token (RFC 7009).
//
// The response is always 200 with an empty body, whether or not the form
// parsed, the token existed, or revocation succeeded. This follows RFC 7009
// §2.2, which has the server answer 200 even for unknown tokens.
func (o *OAuthHandlers) Revoke(w http.ResponseWriter, r *http.Request) {
	if err := r.ParseForm(); err == nil {
		if tok := r.FormValue("token"); tok != "" {
			// Best effort: the outcome is deliberately not surfaced to the caller.
			_ = lib.OAuthTokenRevoke(o.DB, tok)
		}
	}
	w.WriteHeader(http.StatusOK)
}
|
|
|
|
// extractUserID checks JWT from Authorization header or ds_token cookie.
|
|
func (o *OAuthHandlers) extractUserID(r *http.Request) string {
|
|
// Try Authorization header first
|
|
auth := r.Header.Get("Authorization")
|
|
if strings.HasPrefix(auth, "Bearer ") {
|
|
token := strings.TrimPrefix(auth, "Bearer ")
|
|
claims, err := validateJWT(token, o.Cfg.JWTSecret)
|
|
if err == nil {
|
|
session, err := lib.SessionByID(o.DB, claims.SessionID)
|
|
if err == nil && session != nil && !session.Revoked && session.ExpiresAt >= time.Now().UnixMilli() {
|
|
return claims.UserID
|
|
}
|
|
}
|
|
}
|
|
|
|
// Try cookie
|
|
cookie, err := r.Cookie("ds_token")
|
|
if err == nil && cookie.Value != "" {
|
|
claims, err := validateJWT(cookie.Value, o.Cfg.JWTSecret)
|
|
if err == nil {
|
|
session, err := lib.SessionByID(o.DB, claims.SessionID)
|
|
if err == nil && session != nil && !session.Revoked && session.ExpiresAt >= time.Now().UnixMilli() {
|
|
return claims.UserID
|
|
}
|
|
}
|
|
}
|
|
|
|
return ""
|
|
}
|
|
|
|
// serveConsentPage renders the "authorize application" page for appName.
//
// The query string of authorizeURI is passed to the template as
// OriginalQuery, for the form that posts to AuthorizeApprove. Templates are
// looked up relative to the working directory first, then under
// /opt/dealspace. They are parsed on every call and executed as "layout".
//
// The page exists to collect an "approve" click, which makes it a
// clickjacking target. Framing is therefore forbidden, and the response is
// marked no-store. The page is rendered into a buffer first, so a template
// execution error yields a clean 500 instead of a truncated 200 page.
func (o *OAuthHandlers) serveConsentPage(w http.ResponseWriter, appName, authorizeURI string) {
	h := w.Header()
	h.Set("X-Frame-Options", "DENY")
	h.Set("Content-Security-Policy", "frame-ancestors 'none'")
	h.Set("Cache-Control", "no-store")

	// A malformed URI yields an empty query rather than a nil dereference.
	queryString := ""
	if parsed, err := url.Parse(authorizeURI); err == nil {
		queryString = parsed.RawQuery
	}

	firstExisting := func(candidates ...string) string {
		for _, p := range candidates {
			if _, err := os.Stat(p); err == nil {
				return p
			}
		}
		return ""
	}
	layoutPath := firstExisting(
		"portal/templates/layouts/auth.html",
		"/opt/dealspace/portal/templates/layouts/auth.html",
	)
	pagePath := firstExisting(
		"portal/templates/auth/consent.html",
		"/opt/dealspace/portal/templates/auth/consent.html",
	)
	if layoutPath == "" || pagePath == "" {
		http.Error(w, "Consent template not found", http.StatusInternalServerError)
		return
	}

	tmpl, err := template.ParseFiles(layoutPath, pagePath)
	if err != nil {
		http.Error(w, "Template parse error: "+err.Error(), http.StatusInternalServerError)
		return
	}

	data := struct {
		Title         string
		AppName       string
		OriginalQuery string
	}{
		Title:         "Authorize Application — Dealspace",
		AppName:       appName,
		OriginalQuery: queryString,
	}

	var page strings.Builder
	if err := tmpl.ExecuteTemplate(&page, "layout", data); err != nil {
		http.Error(w, "Template render error", http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "text/html; charset=utf-8")
	_, _ = w.Write([]byte(page.String()))
}
|
|
|
|
// validRedirectURI reports whether uri may be used as a redirect target for
// a client whose registered redirect URIs are registered.
//
// Matching rules:
//   - The scheme must always match.
//   - Loopback registrations (localhost, 127.0.0.1, ::1) match any port and
//     any path on the same hostname. Native apps listen on an ephemeral port
//     chosen at runtime (RFC 8252 §7.3).
//   - Any other registration must match host (including port) and path
//     exactly. Only the query may differ.
//   - A URI containing a fragment is never valid (RFC 6749 §3.1.2).
//
// The previous implementation compared only scheme and hostname for every
// registration, so any path on a registered host was accepted.
func validRedirectURI(registered []string, uri string) bool {
	if strings.ContainsRune(uri, '#') {
		return false
	}
	parsed, err := url.Parse(uri)
	if err != nil {
		return false
	}
	for _, reg := range registered {
		regParsed, err := url.Parse(reg)
		if err != nil || regParsed.Scheme != parsed.Scheme {
			continue
		}
		switch regParsed.Hostname() {
		case "localhost", "127.0.0.1", "::1":
			if parsed.Hostname() == regParsed.Hostname() {
				return true
			}
		default:
			if parsed.Host == regParsed.Host && parsed.Path == regParsed.Path {
				return true
			}
		}
	}
	return false
}
|
|
|
|
// oauthError reports an authorization-endpoint error.
//
// With a redirect URI, the error code (and state, if any) is sent back to
// the client by a 302 redirect; errDesc is not included in that case. With
// no redirect URI, a 400 error response carrying errCode and errDesc is
// written directly. Callers are responsible for having validated
// redirectURI against the client's registration.
func oauthError(w http.ResponseWriter, redirectURI, state, errCode, errDesc string) {
	if redirectURI != "" {
		oauthRedirect(w, nil, redirectURI, state, "", errCode)
		return
	}
	ErrorResponse(w, http.StatusBadRequest, errCode, errDesc)
}
|
|
|
|
// oauthRedirect sends a 302 to redirectURI, adding "code", "error" and
// "state" query parameters for whichever of those arguments are non-empty.
// Existing query parameters are kept, but same-named ones are overwritten.
// An unparsable redirectURI yields a 400 instead.
//
// r may be nil. http.Redirect needs a request, so in that case the Location
// header and status are written by hand (with no body).
func oauthRedirect(w http.ResponseWriter, r *http.Request, redirectURI, state, code, errCode string) {
	target, err := url.Parse(redirectURI)
	if err != nil {
		ErrorResponse(w, http.StatusBadRequest, "invalid_request", "Invalid redirect_uri")
		return
	}

	params := target.Query()
	for _, kv := range [...]struct{ key, val string }{
		{"code", code},
		{"error", errCode},
		{"state", state},
	} {
		if kv.val != "" {
			params.Set(kv.key, kv.val)
		}
	}
	target.RawQuery = params.Encode()
	location := target.String()

	if r == nil {
		w.Header().Set("Location", location)
		w.WriteHeader(http.StatusFound)
		return
	}
	http.Redirect(w, r, location, http.StatusFound)
}
|
|
|
|
// tokenError writes a token-endpoint error response: HTTP 400 with a JSON
// body of the form {"error": errCode, "error_description": errDesc}.
func tokenError(w http.ResponseWriter, errCode, errDesc string) {
	payload := map[string]string{
		"error":             errCode,
		"error_description": errDesc,
	}
	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusBadRequest)
	json.NewEncoder(w).Encode(payload)
}
|
|
|
|
// SeedOAuthClient registers the built-in "claude" public client, whose
// redirect URIs are the HTTP loopback hosts (with and without ":0").
//
// The result of OAuthClientCreate is ignored. When the client already
// exists, that call presumably fails (e.g. a duplicate key) and the existing
// row is left as-is — confirm against lib.OAuthClientCreate.
func SeedOAuthClient(db *lib.DB) {
	loopbackRedirects := []string{
		"http://localhost",
		"http://127.0.0.1",
		"http://localhost:0",
		"http://127.0.0.1:0",
	}

	seed := &lib.OAuthClient{
		ClientID:     "claude",
		ClientName:   "Claude",
		RedirectURIs: loopbackRedirects,
		CreatedAt:    time.Now().UnixMilli(),
	}
	_ = lib.OAuthClientCreate(db, seed)
}
|