405 lines
12 KiB
Go
405 lines
12 KiB
Go
package main
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"inou/lib"
|
|
)
|
|
|
|
// OAuth 2.0 Authorization Server Implementation
|
|
// Supports: Authorization Code flow with PKCE (for Flutter/mobile)
|
|
//
|
|
// Endpoints:
|
|
// GET /oauth/authorize - Authorization endpoint (sends anonymous users to login, then redirects back with a code; no consent prompt)
|
|
// POST /oauth/token - Token endpoint (exchanges code for tokens, refreshes)
|
|
// GET /oauth/userinfo - UserInfo endpoint (returns user profile)
|
|
// POST /oauth/revoke - Revoke refresh token
|
|
//
|
|
// Token Strategy:
|
|
// - Access tokens: 15 minutes, stateless (encrypted blob)
|
|
// - Refresh tokens: 30 days, DB-stored, rotated on each use
|
|
|
|
// Token lifetimes used by the OAuth handlers.
const (
	// accessTokenDuration is the validity window passed to lib.TokenCreate for
	// every access token, and is also reported to clients as expires_in.
	accessTokenDuration time.Duration = 15 * time.Minute

	// refreshTokenDuration is 30 days. NOTE(review): no handler in this file
	// references it; presumably it mirrors the refresh-token TTL applied inside
	// lib — confirm.
	refreshTokenDuration time.Duration = 720 * time.Hour
)
|
|
|
|
// oauthError writes an OAuth 2.0 error response (RFC 6749 §5.2 shape): a JSON
// object with "error" and "error_description", sent with the given HTTP status
// and marked uncacheable.
func oauthError(w http.ResponseWriter, err, desc string, code int) {
	type errorBody struct {
		Error       string `json:"error"`
		Description string `json:"error_description"`
	}

	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	hdr.Set("Cache-Control", "no-store")
	w.WriteHeader(code)

	// The status line is already sent, so an encode/write failure cannot be
	// reported to the client; it is deliberately ignored.
	_ = json.NewEncoder(w).Encode(errorBody{Error: err, Description: desc})
}
|
|
|
|
// oauthJSON writes data as a 200 JSON response that must not be cached.
//
// RFC 6749 §5.1 requires responses carrying tokens to include both
// "Cache-Control: no-store" and "Pragma: no-cache". The value is marshalled
// before anything is written, so an unencodable value yields a 500
// server_error response instead of an empty 200.
func oauthJSON(w http.ResponseWriter, data any) {
	h := w.Header()
	h.Set("Content-Type", "application/json")
	h.Set("Cache-Control", "no-store")
	h.Set("Pragma", "no-cache")

	body, err := json.Marshal(data)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		_, _ = w.Write([]byte(`{"error":"server_error","error_description":"Failed to encode response"}` + "\n"))
		return
	}
	// Trailing newline keeps the output byte-identical to json.Encoder's.
	body = append(body, '\n')
	// A write error means the client went away; nothing is left to report.
	_, _ = w.Write(body)
}
|
|
|
|
// handleOAuthAuthorize handles GET /oauth/authorize
|
|
// Parameters: client_id, redirect_uri, response_type, state, code_challenge, code_challenge_method
|
|
func handleOAuthAuthorize(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != "GET" {
|
|
oauthError(w, "invalid_request", "Method must be GET", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
// Parse parameters
|
|
clientID := r.URL.Query().Get("client_id")
|
|
redirectURI := r.URL.Query().Get("redirect_uri")
|
|
responseType := r.URL.Query().Get("response_type")
|
|
state := r.URL.Query().Get("state")
|
|
codeChallenge := r.URL.Query().Get("code_challenge")
|
|
codeChallengeMethod := r.URL.Query().Get("code_challenge_method")
|
|
|
|
// Validate required parameters
|
|
if clientID == "" {
|
|
oauthError(w, "invalid_request", "client_id is required", http.StatusBadRequest)
|
|
return
|
|
}
|
|
if redirectURI == "" {
|
|
oauthError(w, "invalid_request", "redirect_uri is required", http.StatusBadRequest)
|
|
return
|
|
}
|
|
if responseType != "code" {
|
|
oauthError(w, "unsupported_response_type", "Only 'code' response type is supported", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Validate client
|
|
client, err := lib.OAuthClientGet(clientID)
|
|
if err != nil {
|
|
oauthError(w, "invalid_client", "Unknown client_id", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Validate redirect URI
|
|
if !lib.OAuthClientValidRedirectURI(client, redirectURI) {
|
|
oauthError(w, "invalid_request", "Invalid redirect_uri for this client", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Validate PKCE if provided
|
|
if codeChallenge != "" && codeChallengeMethod != "S256" {
|
|
oauthError(w, "invalid_request", "Only S256 code_challenge_method is supported", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Check if user is logged in
|
|
dossier := getLoggedInDossier(r)
|
|
if dossier == nil {
|
|
// Store return URL in cookie and redirect to login
|
|
returnURL := r.URL.String()
|
|
http.SetCookie(w, &http.Cookie{
|
|
Name: "oauth_return",
|
|
Value: returnURL,
|
|
Path: "/",
|
|
MaxAge: 600, // 10 minutes
|
|
HttpOnly: true,
|
|
Secure: true,
|
|
SameSite: http.SameSiteLaxMode,
|
|
})
|
|
http.Redirect(w, r, "/start", http.StatusSeeOther)
|
|
return
|
|
}
|
|
|
|
// User is logged in - generate authorization code
|
|
code, err := lib.OAuthCodeCreate(
|
|
clientID,
|
|
dossier.DossierID,
|
|
redirectURI,
|
|
codeChallenge,
|
|
codeChallengeMethod,
|
|
)
|
|
if err != nil {
|
|
oauthError(w, "server_error", "Failed to create authorization code", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Build redirect URL with code
|
|
redirectURL, _ := url.Parse(redirectURI)
|
|
q := redirectURL.Query()
|
|
q.Set("code", code.Code)
|
|
if state != "" {
|
|
q.Set("state", state)
|
|
}
|
|
redirectURL.RawQuery = q.Encode()
|
|
|
|
http.Redirect(w, r, redirectURL.String(), http.StatusSeeOther)
|
|
}
|
|
|
|
// handleOAuthToken handles POST /oauth/token
|
|
// Supports: authorization_code, refresh_token grant types
|
|
func handleOAuthToken(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != "POST" {
|
|
oauthError(w, "invalid_request", "Method must be POST", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
// Parse form or JSON body
|
|
contentType := r.Header.Get("Content-Type")
|
|
var grantType, code, redirectURI, clientID, clientSecret, refreshToken, codeVerifier string
|
|
|
|
if strings.Contains(contentType, "application/json") {
|
|
var body map[string]string
|
|
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
|
oauthError(w, "invalid_request", "Invalid JSON body", http.StatusBadRequest)
|
|
return
|
|
}
|
|
grantType = body["grant_type"]
|
|
code = body["code"]
|
|
redirectURI = body["redirect_uri"]
|
|
clientID = body["client_id"]
|
|
clientSecret = body["client_secret"]
|
|
refreshToken = body["refresh_token"]
|
|
codeVerifier = body["code_verifier"]
|
|
} else {
|
|
r.ParseForm()
|
|
grantType = r.FormValue("grant_type")
|
|
code = r.FormValue("code")
|
|
redirectURI = r.FormValue("redirect_uri")
|
|
clientID = r.FormValue("client_id")
|
|
clientSecret = r.FormValue("client_secret")
|
|
refreshToken = r.FormValue("refresh_token")
|
|
codeVerifier = r.FormValue("code_verifier")
|
|
}
|
|
|
|
switch grantType {
|
|
case "authorization_code":
|
|
handleAuthorizationCodeGrant(w, clientID, clientSecret, code, redirectURI, codeVerifier)
|
|
case "refresh_token":
|
|
handleRefreshTokenGrant(w, clientID, clientSecret, refreshToken)
|
|
default:
|
|
oauthError(w, "unsupported_grant_type", "Only authorization_code and refresh_token grants are supported", http.StatusBadRequest)
|
|
}
|
|
}
|
|
|
|
// handleAuthorizationCodeGrant exchanges an authorization code for tokens
|
|
func handleAuthorizationCodeGrant(w http.ResponseWriter, clientID, clientSecret, code, redirectURI, codeVerifier string) {
|
|
// Validate client
|
|
client, err := lib.OAuthClientGet(clientID)
|
|
if err != nil {
|
|
oauthError(w, "invalid_client", "Unknown client_id", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Verify client secret (if provided - public clients may not have one)
|
|
if clientSecret != "" && !lib.OAuthClientVerifySecret(client, clientSecret) {
|
|
oauthError(w, "invalid_client", "Invalid client_secret", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Get and validate code
|
|
authCode, err := lib.OAuthCodeGet(code)
|
|
if err != nil {
|
|
oauthError(w, "invalid_grant", "Invalid or expired authorization code", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Verify code belongs to this client
|
|
if authCode.ClientID != clientID {
|
|
oauthError(w, "invalid_grant", "Code was not issued to this client", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Verify redirect URI matches
|
|
if authCode.RedirectURI != redirectURI {
|
|
oauthError(w, "invalid_grant", "redirect_uri mismatch", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Verify PKCE
|
|
if !lib.OAuthCodeVerifyPKCE(authCode, codeVerifier) {
|
|
oauthError(w, "invalid_grant", "Invalid code_verifier", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Mark code as used
|
|
if err := lib.OAuthCodeUse(code); err != nil {
|
|
oauthError(w, "server_error", "Failed to consume code", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Generate tokens
|
|
accessToken := lib.TokenCreate(authCode.DossierID, accessTokenDuration)
|
|
refreshTokenObj, err := lib.OAuthRefreshTokenCreate(clientID, authCode.DossierID)
|
|
if err != nil {
|
|
oauthError(w, "server_error", "Failed to create refresh token", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
oauthJSON(w, map[string]any{
|
|
"access_token": accessToken,
|
|
"token_type": "Bearer",
|
|
"expires_in": int(accessTokenDuration.Seconds()),
|
|
"refresh_token": refreshTokenObj.TokenID,
|
|
})
|
|
}
|
|
|
|
// handleRefreshTokenGrant exchanges a refresh token for new tokens
|
|
func handleRefreshTokenGrant(w http.ResponseWriter, clientID, clientSecret, refreshToken string) {
|
|
// Validate client
|
|
client, err := lib.OAuthClientGet(clientID)
|
|
if err != nil {
|
|
oauthError(w, "invalid_client", "Unknown client_id", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Verify client secret (if provided)
|
|
if clientSecret != "" && !lib.OAuthClientVerifySecret(client, clientSecret) {
|
|
oauthError(w, "invalid_client", "Invalid client_secret", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Get and validate refresh token
|
|
oldToken, err := lib.OAuthRefreshTokenGet(refreshToken)
|
|
if err != nil {
|
|
oauthError(w, "invalid_grant", "Invalid or expired refresh token", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Verify token belongs to this client
|
|
if oldToken.ClientID != clientID {
|
|
oauthError(w, "invalid_grant", "Token was not issued to this client", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Rotate refresh token (revoke old, create new)
|
|
newRefreshToken, err := lib.OAuthRefreshTokenRotate(refreshToken)
|
|
if err != nil {
|
|
oauthError(w, "server_error", "Failed to rotate refresh token", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// Generate new access token
|
|
accessToken := lib.TokenCreate(oldToken.DossierID, accessTokenDuration)
|
|
|
|
oauthJSON(w, map[string]any{
|
|
"access_token": accessToken,
|
|
"token_type": "Bearer",
|
|
"expires_in": int(accessTokenDuration.Seconds()),
|
|
"refresh_token": newRefreshToken.TokenID,
|
|
})
|
|
}
|
|
|
|
// handleOAuthUserinfo handles GET /oauth/userinfo
|
|
// Returns the authenticated user's profile
|
|
func handleOAuthUserinfo(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != "GET" {
|
|
oauthError(w, "invalid_request", "Method must be GET", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
// Get bearer token
|
|
auth := r.Header.Get("Authorization")
|
|
if !strings.HasPrefix(auth, "Bearer ") {
|
|
w.Header().Set("WWW-Authenticate", `Bearer realm="inou"`)
|
|
oauthError(w, "invalid_token", "Bearer token required", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
tokenStr := strings.TrimPrefix(auth, "Bearer ")
|
|
token, err := lib.TokenParse(tokenStr)
|
|
if err != nil {
|
|
w.Header().Set("WWW-Authenticate", `Bearer realm="inou", error="invalid_token"`)
|
|
oauthError(w, "invalid_token", "Invalid or expired token", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Get dossier
|
|
dossier, err := lib.DossierGet("", token.DossierID) // nil ctx - internal operation
|
|
if err != nil || dossier == nil {
|
|
oauthError(w, "invalid_token", "User not found", http.StatusUnauthorized)
|
|
return
|
|
}
|
|
|
|
// Return OpenID Connect compatible userinfo
|
|
oauthJSON(w, map[string]any{
|
|
"sub": dossier.DossierID,
|
|
"name": dossier.Name,
|
|
"email": dossier.Email,
|
|
})
|
|
}
|
|
|
|
// handleOAuthRevoke handles POST /oauth/revoke
|
|
// Revokes a refresh token
|
|
func handleOAuthRevoke(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != "POST" {
|
|
oauthError(w, "invalid_request", "Method must be POST", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
|
|
r.ParseForm()
|
|
token := r.FormValue("token")
|
|
if token == "" {
|
|
oauthError(w, "invalid_request", "token is required", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
// Revoke the token (ignore errors - RFC 7009 says always return 200)
|
|
lib.OAuthRefreshTokenRevoke(token)
|
|
|
|
w.WriteHeader(http.StatusOK)
|
|
}
|
|
|
|
// RegisterOAuthRoutes registers OAuth endpoints
|
|
func RegisterOAuthRoutes(mux *http.ServeMux) {
|
|
mux.HandleFunc("/oauth/authorize", handleOAuthAuthorize)
|
|
mux.HandleFunc("/oauth/token", handleOAuthToken)
|
|
mux.HandleFunc("/oauth/userinfo", handleOAuthUserinfo)
|
|
mux.HandleFunc("/oauth/revoke", handleOAuthRevoke)
|
|
}
|
|
|
|
// CreateAnthropicClient creates the OAuth client for Anthropic/Claude
|
|
// Call this once during setup
|
|
func CreateAnthropicClient() error {
|
|
// Check if already exists
|
|
_, err := lib.OAuthClientGet("anthropic")
|
|
if err == nil {
|
|
return nil // Already exists
|
|
}
|
|
|
|
// Create client with Anthropic's callback URLs
|
|
redirectURIs := []string{
|
|
"https://claude.ai/api/mcp/auth_callback",
|
|
"https://claude.com/api/mcp/auth_callback",
|
|
"http://localhost:6274/oauth/callback",
|
|
"http://localhost:6274/oauth/callback/debug",
|
|
}
|
|
|
|
client, secret, err := lib.OAuthClientCreate("Anthropic Claude", redirectURIs)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
fmt.Printf("Created Anthropic OAuth client:\n")
|
|
fmt.Printf(" Client ID: %s\n", client.ClientID)
|
|
fmt.Printf(" Client Secret: %s\n", secret)
|
|
fmt.Printf(" (Save the secret - it cannot be retrieved later)\n")
|
|
|
|
return nil
|
|
}
|
|
|
|
// EnsureBridgeClient creates the OAuth client for the MCP bridge (public client, no secret)
|
|
// Called on startup to ensure the client exists
|
|
func EnsureBridgeClient() error {
|
|
_, err := lib.OAuthClientGet("inou-bridge")
|
|
if err == nil {
|
|
return nil // Already exists
|
|
}
|
|
|
|
// Create public client (no redirect URIs needed - uses refresh_token grant only)
|
|
return lib.OAuthClientCreatePublic("inou-bridge", "Inou Bridge")
|
|
}
|