221 lines
5.4 KiB
Go
221 lines
5.4 KiB
Go
package api
|
|
|
|
import (
|
|
"crypto/hmac"
|
|
"crypto/sha256"
|
|
"crypto/subtle"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/inou-ai/messaging-center/internal/core"
|
|
)
|
|
|
|
// TokenInfo describes one access token kept by an OAuthProvider: the client
// it was issued to, the scopes it carries, and when it stops being valid.
type TokenInfo struct {
	// ClientID is the ID of the OAuth client the token was issued to.
	ClientID string
	// Scopes lists the scopes granted to the token.
	Scopes []string
	// ExpiresAt is the instant after which the token is no longer accepted.
	ExpiresAt time.Time
}
|
|
|
|
// OAuthProvider handles OAuth 2.0 authentication.
|
|
type OAuthProvider struct {
|
|
config *core.OAuthConfig
|
|
tokens map[string]*TokenInfo
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
// NewOAuthProvider creates a new OAuth provider.
|
|
func NewOAuthProvider(cfg *core.OAuthConfig) *OAuthProvider {
|
|
return &OAuthProvider{
|
|
config: cfg,
|
|
tokens: make(map[string]*TokenInfo),
|
|
}
|
|
}
|
|
|
|
// TokenResponse is the JSON body returned by the token endpoint when a token
// has been issued.
type TokenResponse struct {
	// AccessToken is the newly issued token string.
	AccessToken string `json:"access_token"`
	// TokenType names the token scheme; HandleToken sets it to "Bearer".
	TokenType string `json:"token_type"`
	// ExpiresIn is the token lifetime in whole seconds.
	ExpiresIn int `json:"expires_in"`
	// Scope is the space-separated list of granted scopes; it is omitted
	// from the JSON when empty.
	Scope string `json:"scope,omitempty"`
}
|
|
|
|
// TokenError is the JSON body sent when an OAuth request is rejected.
type TokenError struct {
	// Error is the machine-readable error code, e.g. "invalid_client".
	Error string `json:"error"`
	// Description is optional human-readable detail; it is omitted from the
	// JSON when empty.
	Description string `json:"error_description,omitempty"`
}
|
|
|
|
// HandleToken handles POST /oauth/token.
|
|
func (p *OAuthProvider) HandleToken(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
writeJSON(w, http.StatusMethodNotAllowed, TokenError{Error: "method_not_allowed"})
|
|
return
|
|
}
|
|
|
|
if err := r.ParseForm(); err != nil {
|
|
writeJSON(w, http.StatusBadRequest, TokenError{Error: "invalid_request"})
|
|
return
|
|
}
|
|
|
|
grantType := r.FormValue("grant_type")
|
|
if grantType != "client_credentials" {
|
|
writeJSON(w, http.StatusBadRequest, TokenError{
|
|
Error: "unsupported_grant_type",
|
|
Description: "Only client_credentials is supported",
|
|
})
|
|
return
|
|
}
|
|
|
|
// Get client credentials from Basic auth or form
|
|
clientID, clientSecret, ok := r.BasicAuth()
|
|
if !ok {
|
|
clientID = r.FormValue("client_id")
|
|
clientSecret = r.FormValue("client_secret")
|
|
}
|
|
|
|
if clientID == "" || clientSecret == "" {
|
|
writeJSON(w, http.StatusUnauthorized, TokenError{Error: "invalid_client"})
|
|
return
|
|
}
|
|
|
|
// Find and validate client
|
|
var client *core.OAuthClient
|
|
for i := range p.config.Clients {
|
|
if p.config.Clients[i].ID == clientID {
|
|
client = &p.config.Clients[i]
|
|
break
|
|
}
|
|
}
|
|
|
|
if client == nil || !secureCompare(client.Secret, clientSecret) {
|
|
writeJSON(w, http.StatusUnauthorized, TokenError{Error: "invalid_client"})
|
|
return
|
|
}
|
|
|
|
// Parse requested scopes
|
|
requestedScopes := strings.Fields(r.FormValue("scope"))
|
|
var grantedScopes []string
|
|
|
|
if len(requestedScopes) == 0 {
|
|
// Grant all client scopes
|
|
grantedScopes = client.Scopes
|
|
} else {
|
|
// Grant only requested scopes that client has
|
|
for _, s := range requestedScopes {
|
|
if client.HasScope(s) {
|
|
grantedScopes = append(grantedScopes, s)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Generate token
|
|
token := generateToken()
|
|
expiresAt := time.Now().Add(p.config.AccessTokenTTL)
|
|
|
|
p.mu.Lock()
|
|
p.tokens[token] = &TokenInfo{
|
|
ClientID: clientID,
|
|
Scopes: grantedScopes,
|
|
ExpiresAt: expiresAt,
|
|
}
|
|
p.mu.Unlock()
|
|
|
|
writeJSON(w, http.StatusOK, TokenResponse{
|
|
AccessToken: token,
|
|
TokenType: "Bearer",
|
|
ExpiresIn: int(p.config.AccessTokenTTL.Seconds()),
|
|
Scope: strings.Join(grantedScopes, " "),
|
|
})
|
|
}
|
|
|
|
// ValidateToken validates a bearer token and returns its info.
|
|
func (p *OAuthProvider) ValidateToken(token string) *TokenInfo {
|
|
p.mu.RLock()
|
|
defer p.mu.RUnlock()
|
|
|
|
info, ok := p.tokens[token]
|
|
if !ok {
|
|
return nil
|
|
}
|
|
|
|
if time.Now().After(info.ExpiresAt) {
|
|
return nil
|
|
}
|
|
|
|
return info
|
|
}
|
|
|
|
// RequireAuth returns middleware that requires authentication.
|
|
func (p *OAuthProvider) RequireAuth(scopes ...string) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
auth := r.Header.Get("Authorization")
|
|
if auth == "" {
|
|
w.Header().Set("WWW-Authenticate", "Bearer")
|
|
writeJSON(w, http.StatusUnauthorized, TokenError{Error: "unauthorized"})
|
|
return
|
|
}
|
|
|
|
parts := strings.SplitN(auth, " ", 2)
|
|
if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
|
|
writeJSON(w, http.StatusUnauthorized, TokenError{Error: "invalid_token"})
|
|
return
|
|
}
|
|
|
|
info := p.ValidateToken(parts[1])
|
|
if info == nil {
|
|
writeJSON(w, http.StatusUnauthorized, TokenError{Error: "invalid_token"})
|
|
return
|
|
}
|
|
|
|
// Check required scopes
|
|
for _, required := range scopes {
|
|
hasScope := false
|
|
for _, s := range info.Scopes {
|
|
if s == required || s == "admin" {
|
|
hasScope = true
|
|
break
|
|
}
|
|
}
|
|
if !hasScope {
|
|
writeJSON(w, http.StatusForbidden, TokenError{
|
|
Error: "insufficient_scope",
|
|
Description: "Required scope: " + required,
|
|
})
|
|
return
|
|
}
|
|
}
|
|
|
|
// Store token info in context
|
|
ctx := r.Context()
|
|
ctx = withTokenInfo(ctx, info)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
}
|
|
|
|
// generateToken creates a secure random token.
|
|
func generateToken() string {
|
|
id := uuid.New()
|
|
h := hmac.New(sha256.New, []byte("mc-token-salt"))
|
|
h.Write([]byte(id.String()))
|
|
return base64.RawURLEncoding.EncodeToString(h.Sum(nil))
|
|
}
|
|
|
|
// secureCompare reports whether a and b are equal, using a comparison whose
// duration does not depend on where the two strings differ.
//
// subtle.ConstantTimeCompare returns immediately when its inputs differ in
// length, so both strings are first reduced to fixed-size SHA-256 digests;
// the comparison itself then always runs over equal-length inputs.
func secureCompare(a, b string) bool {
	digestA := sha256.Sum256([]byte(a))
	digestB := sha256.Sum256([]byte(b))
	return subtle.ConstantTimeCompare(digestA[:], digestB[:]) == 1
}
|
|
|
|
// writeJSON sends v as a JSON response body with the given HTTP status and a
// Content-Type of application/json.
//
// v is encoded before the status line is written, so an encoding failure is
// reported as a 500 with a generic error body instead of going out under the
// caller's status with an empty body.
func writeJSON(w http.ResponseWriter, status int, v any) {
	w.Header().Set("Content-Type", "application/json")

	body, err := json.Marshal(v)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		// Write errors are ignored: the client is gone and nothing else
		// can be done with the response.
		_, _ = w.Write([]byte(`{"error":"server_error"}` + "\n"))
		return
	}

	w.WriteHeader(status)
	// Keep the trailing newline that json.Encoder.Encode would have emitted.
	// Write errors are ignored for the same reason as above.
	_, _ = w.Write(append(body, '\n'))
}
|