dealspace/api/oauth.go

435 lines
12 KiB
Go

package api
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"html/template"
"net/http"
"net/url"
"os"
"strings"
"time"
"github.com/mish/dealspace/lib"
)
// OAuthHandlers bundles the shared dependencies used by the OAuth
// authorization-server endpoints (metadata, authorize, token, revoke).
type OAuthHandlers struct {
	DB  *lib.DB     // store queried for OAuth clients, codes, tokens, and login sessions
	Cfg *lib.Config // application config; JWTSecret is used to validate user JWTs
}
// NewOAuthHandlers returns an OAuthHandlers wired to the given database and
// configuration. Neither argument is copied; the handlers share them.
func NewOAuthHandlers(db *lib.DB, cfg *lib.Config) *OAuthHandlers {
	h := new(OAuthHandlers)
	h.DB = db
	h.Cfg = cfg
	return h
}
// baseURL reconstructs the externally visible origin ("scheme://host") of r.
// The scheme is "https" when the request arrived over TLS or when an
// X-Forwarded-Proto header value starts with "https"; otherwise "http".
// The host is taken verbatim from r.Host.
func baseURL(r *http.Request) string {
	secure := r.TLS != nil || strings.HasPrefix(r.Header.Get("X-Forwarded-Proto"), "https")
	if secure {
		return "https://" + r.Host
	}
	return "http://" + r.Host
}
// Metadata serves GET /.well-known/oauth-authorization-server (RFC 8414).
// Every advertised endpoint is rooted at the request's own origin, as
// computed by baseURL. The document advertises only the authorization-code
// grant with S256 PKCE and public clients (token endpoint auth method "none").
func (o *OAuthHandlers) Metadata(w http.ResponseWriter, r *http.Request) {
	issuer := baseURL(r)
	endpoint := func(path string) string { return issuer + path }

	doc := map[string]any{
		"issuer":                                issuer,
		"authorization_endpoint":                endpoint("/oauth/authorize"),
		"token_endpoint":                        endpoint("/oauth/token"),
		"revocation_endpoint":                   endpoint("/oauth/revoke"),
		"response_types_supported":              []string{"code"},
		"grant_types_supported":                 []string{"authorization_code"},
		"code_challenge_methods_supported":      []string{"S256"},
		"token_endpoint_auth_methods_supported": []string{"none"},
	}
	JSONResponse(w, http.StatusOK, doc)
}
// ResourceMetadata serves GET /.well-known/oauth-protected-resource (RFC 9728).
// This server acts as its own authorization server, so the resource
// identifier and the single listed authorization server are both the
// request's origin.
func (o *OAuthHandlers) ResourceMetadata(w http.ResponseWriter, r *http.Request) {
	origin := baseURL(r)
	doc := map[string]any{
		"resource":              origin,
		"authorization_servers": []string{origin},
	}
	JSONResponse(w, http.StatusOK, doc)
}
// Authorize handles GET /oauth/authorize — shows consent page or redirects to login.
//
// Validation order follows RFC 6749 §4.1.2.1. The client_id and redirect_uri
// are checked first. If either is invalid, the error is rendered directly to
// the user agent, never redirected, so this endpoint cannot be abused as an
// open redirector. Only once redirect_uri is known to belong to the client
// are the remaining parameter errors (response_type, PKCE) reported back to
// it as an error redirect.
//
// Authenticated users are shown the consent page. Unauthenticated users are
// sent to /app/login with this request as the "next" target. The scope
// parameter is not read here; it reaches the POST handler inside the
// original query string embedded in the consent form.
func (o *OAuthHandlers) Authorize(w http.ResponseWriter, r *http.Request) {
	q := r.URL.Query()
	clientID := q.Get("client_id")
	redirectURI := q.Get("redirect_uri")
	state := q.Get("state")

	// Client and redirect_uri first: until both are validated we must not
	// send the user agent anywhere the request told us to.
	client, err := lib.OAuthClientByID(o.DB, clientID)
	if err != nil || client == nil {
		ErrorResponse(w, http.StatusBadRequest, "invalid_client", "Unknown client_id")
		return
	}
	if !validRedirectURI(client.RedirectURIs, redirectURI) {
		ErrorResponse(w, http.StatusBadRequest, "invalid_request", "Invalid redirect_uri")
		return
	}

	// redirect_uri is trusted from here on, so protocol errors go back to it.
	if q.Get("response_type") != "code" {
		oauthError(w, redirectURI, state, "unsupported_response_type", "Only response_type=code is supported")
		return
	}
	if q.Get("code_challenge") == "" || q.Get("code_challenge_method") != "S256" {
		oauthError(w, redirectURI, state, "invalid_request", "PKCE with S256 is required")
		return
	}

	// Check if user is authenticated via JWT in cookie or Authorization header.
	if o.extractUserID(r) == "" {
		// Return here after login; RequestURI is a path+query, so the login
		// page receives a same-origin relative target.
		loginURL := "/app/login?next=" + url.QueryEscape(r.URL.RequestURI())
		http.Redirect(w, r, loginURL, http.StatusFound)
		return
	}

	o.serveConsentPage(w, client.ClientName, r.URL.RequestURI())
}
// AuthorizeApprove handles POST /oauth/authorize — processes consent approval.
//
// The consent form echoes the original authorize query string in the hidden
// "original_query" field. That field is client-controlled, so everything in
// it is re-validated here:
//   - client_id must name a registered client;
//   - redirect_uri must be registered for that client, checked BEFORE any
//     redirect (including the "deny" path) so this cannot act as an open
//     redirector;
//   - response_type must be "code" and PKCE must be S256 with a challenge.
//
// On approval, a 32-byte random code valid for 10 minutes is stored together
// with the PKCE challenge and scope, then sent to redirect_uri with state.
//
// NOTE(review): approval is authorized solely by the caller's JWT (header or
// ds_token cookie). If the cookie is sent on cross-site POSTs, a forged form
// submission could approve consent. Consider an anti-CSRF token in the
// consent form (or SameSite=Strict on ds_token) — TODO confirm cookie flags.
func (o *OAuthHandlers) AuthorizeApprove(w http.ResponseWriter, r *http.Request) {
	if err := r.ParseForm(); err != nil {
		ErrorResponse(w, http.StatusBadRequest, "invalid_request", "Invalid form data")
		return
	}
	parsedQuery, err := url.ParseQuery(r.FormValue("original_query"))
	if err != nil {
		ErrorResponse(w, http.StatusBadRequest, "invalid_request", "Invalid query parameters")
		return
	}
	clientID := parsedQuery.Get("client_id")
	redirectURI := parsedQuery.Get("redirect_uri")
	codeChallenge := parsedQuery.Get("code_challenge")
	codeChallengeMethod := parsedQuery.Get("code_challenge_method")
	state := parsedQuery.Get("state")
	scope := parsedQuery.Get("scope")

	// Validate client and redirect_uri before any redirect, including denial.
	client, err := lib.OAuthClientByID(o.DB, clientID)
	if err != nil || client == nil {
		ErrorResponse(w, http.StatusBadRequest, "invalid_client", "Unknown client_id")
		return
	}
	if !validRedirectURI(client.RedirectURIs, redirectURI) {
		ErrorResponse(w, http.StatusBadRequest, "invalid_request", "Invalid redirect_uri")
		return
	}

	if r.FormValue("action") == "deny" {
		oauthRedirect(w, r, redirectURI, state, "", "access_denied")
		return
	}

	// The GET handler enforced these, but the hidden field may have been edited.
	if parsedQuery.Get("response_type") != "code" {
		oauthRedirect(w, r, redirectURI, state, "", "unsupported_response_type")
		return
	}
	if codeChallenge == "" || codeChallengeMethod != "S256" {
		oauthRedirect(w, r, redirectURI, state, "", "invalid_request")
		return
	}

	userID := o.extractUserID(r)
	if userID == "" {
		ErrorResponse(w, http.StatusUnauthorized, "login_required", "Authentication required")
		return
	}

	codeBytes := make([]byte, 32)
	if _, err := rand.Read(codeBytes); err != nil {
		ErrorResponse(w, http.StatusInternalServerError, "internal", "Failed to generate code")
		return
	}
	codeStr := hex.EncodeToString(codeBytes)

	oauthCode := &lib.OAuthCode{
		Code:          codeStr,
		ClientID:      clientID,
		UserID:        userID,
		RedirectURI:   redirectURI,
		CodeChallenge: codeChallenge,
		Scope:         scope,
		ExpiresAt:     time.Now().Add(10 * time.Minute).UnixMilli(),
		Used:          false,
	}
	if err := lib.OAuthCodeCreate(o.DB, oauthCode); err != nil {
		ErrorResponse(w, http.StatusInternalServerError, "internal", "Failed to store authorization code")
		return
	}
	oauthRedirect(w, r, redirectURI, state, codeStr, "")
}
// Token handles POST /oauth/token — exchanges authorization code for access token.
func (o *OAuthHandlers) Token(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
tokenError(w, "invalid_request", "Invalid form data")
return
}
grantType := r.FormValue("grant_type")
if grantType != "authorization_code" {
tokenError(w, "unsupported_grant_type", "Only authorization_code is supported")
return
}
codeStr := r.FormValue("code")
redirectURI := r.FormValue("redirect_uri")
clientID := r.FormValue("client_id")
codeVerifier := r.FormValue("code_verifier")
if codeStr == "" || clientID == "" || codeVerifier == "" {
tokenError(w, "invalid_request", "Missing required parameters")
return
}
// Consume the code (marks used, checks expiry)
code, err := lib.OAuthCodeConsume(o.DB, codeStr)
if err != nil {
tokenError(w, "invalid_grant", "Invalid or expired authorization code")
return
}
// Verify client_id matches
if code.ClientID != clientID {
tokenError(w, "invalid_grant", "Client mismatch")
return
}
// Verify redirect_uri matches
if code.RedirectURI != redirectURI {
tokenError(w, "invalid_grant", "Redirect URI mismatch")
return
}
// Verify PKCE: SHA256(code_verifier) must match code_challenge
verifierHash := sha256.Sum256([]byte(codeVerifier))
computedChallenge := base64.RawURLEncoding.EncodeToString(verifierHash[:])
if computedChallenge != code.CodeChallenge {
tokenError(w, "invalid_grant", "PKCE verification failed")
return
}
// Generate access token
tokenBytes := make([]byte, 64)
if _, err := rand.Read(tokenBytes); err != nil {
tokenError(w, "server_error", "Failed to generate token")
return
}
tokenStr := hex.EncodeToString(tokenBytes)
now := time.Now()
expiresIn := int64(24 * 60 * 60) // 24 hours in seconds
oauthToken := &lib.OAuthToken{
Token: tokenStr,
ClientID: clientID,
UserID: code.UserID,
Scope: code.Scope,
ExpiresAt: now.Add(24 * time.Hour).UnixMilli(),
Revoked: false,
CreatedAt: now.UnixMilli(),
}
if err := lib.OAuthTokenCreate(o.DB, oauthToken); err != nil {
tokenError(w, "server_error", "Failed to store token")
return
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Cache-Control", "no-store")
json.NewEncoder(w).Encode(map[string]any{
"access_token": tokenStr,
"token_type": "Bearer",
"expires_in": expiresIn,
})
}
// Revoke handles POST /oauth/revoke — revokes an access token (RFC 7009).
// It always answers 200 with an empty body, whether the form was malformed,
// the token was missing or unknown, or the revocation call failed. The
// lib.OAuthTokenRevoke error is deliberately discarded.
func (o *OAuthHandlers) Revoke(w http.ResponseWriter, r *http.Request) {
	if err := r.ParseForm(); err == nil {
		if token := r.FormValue("token"); token != "" {
			_ = lib.OAuthTokenRevoke(o.DB, token)
		}
	}
	w.WriteHeader(http.StatusOK)
}
// extractUserID returns the ID of the logged-in user, or "" if none.
//
// A "Bearer " Authorization header is tried first, then the ds_token cookie.
// A candidate JWT is accepted only when validateJWT succeeds with
// Cfg.JWTSecret and the session it names exists, is not revoked, and has not
// expired. A header token that fails these checks falls through to the
// cookie.
func (o *OAuthHandlers) extractUserID(r *http.Request) string {
	// resolve validates one raw JWT against its backing session.
	resolve := func(raw string) (userID string, ok bool) {
		claims, err := validateJWT(raw, o.Cfg.JWTSecret)
		if err != nil {
			return "", false
		}
		session, err := lib.SessionByID(o.DB, claims.SessionID)
		if err != nil || session == nil || session.Revoked {
			return "", false
		}
		if session.ExpiresAt < time.Now().UnixMilli() {
			return "", false
		}
		return claims.UserID, true
	}

	if auth := r.Header.Get("Authorization"); strings.HasPrefix(auth, "Bearer ") {
		if userID, ok := resolve(strings.TrimPrefix(auth, "Bearer ")); ok {
			return userID
		}
	}
	if cookie, err := r.Cookie("ds_token"); err == nil && cookie.Value != "" {
		if userID, ok := resolve(cookie.Value); ok {
			return userID
		}
	}
	return ""
}
// serveConsentPage renders the OAuth consent screen for appName.
//
// The raw query string of authorizeURI (the original GET /oauth/authorize
// request) is passed to the template as OriginalQuery. The consent form is
// expected to echo it back to AuthorizeApprove. Templates are looked up
// relative to the working directory first, then under /opt/dealspace.
//
// The page is rendered into a buffer before anything is written, so a
// template execution failure yields a clean 500 instead of a truncated
// 200 page. The response forbids framing (clickjacking defence for consent
// screens, RFC 6749 §10.13) and caching.
func (o *OAuthHandlers) serveConsentPage(w http.ResponseWriter, appName, authorizeURI string) {
	parsed, err := url.Parse(authorizeURI)
	if err != nil {
		http.Error(w, "Invalid authorization request", http.StatusBadRequest)
		return
	}

	// firstExisting returns the first candidate path that os.Stat accepts.
	firstExisting := func(candidates ...string) string {
		for _, p := range candidates {
			if _, err := os.Stat(p); err == nil {
				return p
			}
		}
		return ""
	}
	layoutPath := firstExisting(
		"portal/templates/layouts/auth.html",
		"/opt/dealspace/portal/templates/layouts/auth.html",
	)
	pagePath := firstExisting(
		"portal/templates/auth/consent.html",
		"/opt/dealspace/portal/templates/auth/consent.html",
	)
	if layoutPath == "" || pagePath == "" {
		http.Error(w, "Consent template not found", http.StatusInternalServerError)
		return
	}

	tmpl, err := template.ParseFiles(layoutPath, pagePath)
	if err != nil {
		http.Error(w, "Template parse error: "+err.Error(), http.StatusInternalServerError)
		return
	}

	data := struct {
		Title         string
		AppName       string
		OriginalQuery string
	}{
		Title:         "Authorize Application — Dealspace",
		AppName:       appName,
		OriginalQuery: parsed.RawQuery,
	}

	var page strings.Builder
	if err := tmpl.ExecuteTemplate(&page, "layout", data); err != nil {
		http.Error(w, "Template render error: "+err.Error(), http.StatusInternalServerError)
		return
	}

	h := w.Header()
	h.Set("Content-Type", "text/html; charset=utf-8")
	h.Set("Cache-Control", "no-store")
	h.Set("X-Frame-Options", "DENY")
	_, _ = w.Write([]byte(page.String()))
}
// validRedirectURI reports whether uri may be used as a redirect target for a
// client whose registered redirect URIs are registered.
//
// Matching rules:
//   - uri must be absolute (scheme and host present) and must not contain a
//     fragment (RFC 6749 §3.1.2).
//   - The scheme must match exactly. Host names match case-insensitively.
//   - For loopback hosts (localhost, 127.0.0.1, ::1) the port is ignored, so
//     native apps can listen on an ephemeral port (RFC 8252 §7.3). For every
//     other host the port must match too.
//   - If the registered URI specifies a path (anything other than "" or "/"),
//     the requested path must equal it. Registrations without a path accept
//     any path on the matching origin.
func validRedirectURI(registered []string, uri string) bool {
	if strings.Contains(uri, "#") {
		return false
	}
	parsed, err := url.Parse(uri)
	if err != nil || parsed.Scheme == "" || parsed.Host == "" {
		return false
	}

	isLoopback := func(host string) bool {
		switch strings.ToLower(host) {
		case "localhost", "127.0.0.1", "::1":
			return true
		}
		return false
	}

	for _, reg := range registered {
		want, err := url.Parse(reg)
		if err != nil {
			continue
		}
		if parsed.Scheme != want.Scheme || !strings.EqualFold(parsed.Hostname(), want.Hostname()) {
			continue
		}
		if !isLoopback(want.Hostname()) && parsed.Port() != want.Port() {
			continue
		}
		if p := want.EscapedPath(); p != "" && p != "/" && parsed.EscapedPath() != p {
			continue
		}
		return true
	}
	return false
}
// oauthError reports an authorization-endpoint error. With a redirect URI it
// sends a 302 carrying errCode (and state, if non-empty) to that URI. Only
// errCode is forwarded there; errDesc is not. Without a redirect URI it
// renders a 400 error response containing both errCode and errDesc.
// Callers must have validated redirectURI before passing it here.
func oauthError(w http.ResponseWriter, redirectURI, state, errCode, errDesc string) {
	if redirectURI != "" {
		oauthRedirect(w, nil, redirectURI, state, "", errCode)
		return
	}
	ErrorResponse(w, http.StatusBadRequest, errCode, errDesc)
}
// oauthRedirect sends a 302 to redirectURI, merging "code", "error", and
// "state" into its existing query. Empty values are omitted. If redirectURI
// does not parse, a 400 invalid_request response is written instead.
//
// When r is nil, the Location header and status are written directly, since
// http.Redirect requires a request.
func oauthRedirect(w http.ResponseWriter, r *http.Request, redirectURI, state, code, errCode string) {
	target, err := url.Parse(redirectURI)
	if err != nil {
		ErrorResponse(w, http.StatusBadRequest, "invalid_request", "Invalid redirect_uri")
		return
	}

	params := target.Query()
	for _, kv := range [...]struct{ key, val string }{
		{"code", code},
		{"error", errCode},
		{"state", state},
	} {
		if kv.val != "" {
			params.Set(kv.key, kv.val)
		}
	}
	target.RawQuery = params.Encode()
	location := target.String()

	if r == nil {
		w.Header().Set("Location", location)
		w.WriteHeader(http.StatusFound)
		return
	}
	http.Redirect(w, r, location, http.StatusFound)
}
// tokenError writes a token-endpoint error: HTTP 400 with a JSON body of the
// form {"error": errCode, "error_description": errDesc} (RFC 6749 §5.2).
// The encoder's error is ignored because the status has already been sent.
func tokenError(w http.ResponseWriter, errCode, errDesc string) {
	body := map[string]string{
		"error":             errCode,
		"error_description": errDesc,
	}
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusBadRequest)
	_ = json.NewEncoder(w).Encode(body)
}
// SeedOAuthClient registers the built-in "claude" public OAuth client with
// loopback-only redirect URIs. The error from lib.OAuthClientCreate is
// discarded. Presumably it fails harmlessly when the client already exists,
// which makes this safe to call on every startup — TODO confirm against
// lib.OAuthClientCreate.
func SeedOAuthClient(db *lib.DB) {
	loopbackRedirects := []string{
		"http://localhost",
		"http://127.0.0.1",
		"http://localhost:0",
		"http://127.0.0.1:0",
	}
	_ = lib.OAuthClientCreate(db, &lib.OAuthClient{
		ClientID:     "claude",
		ClientName:   "Claude",
		RedirectURIs: loopbackRedirects,
		CreatedAt:    time.Now().UnixMilli(),
	})
}