messaging-center/internal/api/oauth.go

221 lines
5.4 KiB
Go

package api
import (
"crypto/hmac"
"crypto/sha256"
"crypto/subtle"
"encoding/base64"
"encoding/json"
"net/http"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/inou-ai/messaging-center/internal/core"
)
// TokenInfo is the metadata recorded for one issued access token: which
// client it belongs to, what it may do, and when it stops being valid.
type TokenInfo struct {
	// ClientID is the ID of the OAuth client the token was issued to.
	ClientID string
	// Scopes lists the scopes granted to the token.
	Scopes []string
	// ExpiresAt is the instant after which the token is no longer accepted.
	ExpiresAt time.Time
}
// OAuthProvider issues and validates bearer access tokens for OAuth 2.0.
// Issued tokens live only in the provider's in-memory map.
type OAuthProvider struct {
	// config supplies the registered clients and the access-token TTL.
	config *core.OAuthConfig
	// tokens maps each issued access-token string to its metadata.
	tokens map[string]*TokenInfo
	// mu guards tokens.
	mu sync.RWMutex
}
// NewOAuthProvider builds an OAuthProvider that reads its clients and token
// TTL from cfg and starts with an empty token store.
func NewOAuthProvider(cfg *core.OAuthConfig) *OAuthProvider {
	provider := new(OAuthProvider)
	provider.config = cfg
	// The map must be allocated up front: writing to a nil map panics.
	provider.tokens = map[string]*TokenInfo{}
	return provider
}
// TokenResponse is the JSON body returned by the token endpoint on success.
type TokenResponse struct {
	// AccessToken is the newly issued token string.
	AccessToken string `json:"access_token"`
	// TokenType names how the token is presented, e.g. "Bearer".
	TokenType string `json:"token_type"`
	// ExpiresIn is the token lifetime in whole seconds.
	ExpiresIn int `json:"expires_in"`
	// Scope is the space-separated list of granted scopes; it is left out of
	// the JSON when empty.
	Scope string `json:"scope,omitempty"`
}
// TokenError is the JSON body returned when a request is rejected.
type TokenError struct {
	// Error is the machine-readable error code, e.g. "invalid_client".
	Error string `json:"error"`
	// Description is an optional human-readable explanation; it is left out
	// of the JSON when empty.
	Description string `json:"error_description,omitempty"`
}
// HandleToken handles POST /oauth/token for the client_credentials grant.
//
// The client authenticates with HTTP Basic auth or, when no Basic credentials
// are present, with the client_id and client_secret form fields. On success a
// new bearer token is recorded in the in-memory store and returned as a
// TokenResponse. Failures are reported as a TokenError:
//
//   - 405 method_not_allowed: the request method is not POST
//   - 400 invalid_request: the form could not be parsed
//   - 400 unsupported_grant_type: grant_type is not client_credentials
//   - 401 invalid_client: credentials are missing or match no configured client
//   - 400 invalid_scope: scopes were requested but none can be granted
//
// Leaving out the scope parameter grants every scope configured for the
// client; otherwise the grant is the subset of the requested scopes that the
// client holds.
func (p *OAuthProvider) HandleToken(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		// A 405 response has to tell the caller which method is accepted.
		w.Header().Set("Allow", http.MethodPost)
		writeJSON(w, http.StatusMethodNotAllowed, TokenError{Error: "method_not_allowed"})
		return
	}
	if err := r.ParseForm(); err != nil {
		writeJSON(w, http.StatusBadRequest, TokenError{Error: "invalid_request"})
		return
	}
	if r.FormValue("grant_type") != "client_credentials" {
		writeJSON(w, http.StatusBadRequest, TokenError{
			Error:       "unsupported_grant_type",
			Description: "Only client_credentials is supported",
		})
		return
	}

	// Basic auth takes precedence; the form fields are only a fallback.
	clientID, clientSecret, ok := r.BasicAuth()
	if !ok {
		clientID = r.FormValue("client_id")
		clientSecret = r.FormValue("client_secret")
	}
	if clientID == "" || clientSecret == "" {
		writeJSON(w, http.StatusUnauthorized, TokenError{Error: "invalid_client"})
		return
	}

	// Unknown client IDs and wrong secrets deliberately share one response so
	// the reply does not reveal which of the two was wrong.
	var client *core.OAuthClient
	for i := range p.config.Clients {
		if p.config.Clients[i].ID == clientID {
			client = &p.config.Clients[i]
			break
		}
	}
	if client == nil || !secureCompare(client.Secret, clientSecret) {
		writeJSON(w, http.StatusUnauthorized, TokenError{Error: "invalid_client"})
		return
	}

	requested := strings.Fields(r.FormValue("scope"))
	var granted []string
	if len(requested) == 0 {
		// Copy rather than alias: the stored token must not share a backing
		// array with the client's configured scope list.
		granted = append([]string(nil), client.Scopes...)
	} else {
		for _, scope := range requested {
			if client.HasScope(scope) {
				granted = append(granted, scope)
			}
		}
		if len(granted) == 0 {
			// Issuing a token here would omit "scope" from the response, which
			// a client reads as "you got exactly what you asked for".
			writeJSON(w, http.StatusBadRequest, TokenError{
				Error:       "invalid_scope",
				Description: "None of the requested scopes are available to this client",
			})
			return
		}
	}

	ttl := p.config.AccessTokenTTL
	token := generateToken()
	now := time.Now()

	p.mu.Lock()
	// Drop expired entries while the write lock is held anyway; nothing else
	// removes them, so without this the map would grow without bound.
	for key, info := range p.tokens {
		if now.After(info.ExpiresAt) {
			delete(p.tokens, key)
		}
	}
	p.tokens[token] = &TokenInfo{
		ClientID:  clientID,
		Scopes:    granted,
		ExpiresAt: now.Add(ttl),
	}
	p.mu.Unlock()

	// Responses that carry a token must not be cached (RFC 6749, section 5.1).
	w.Header().Set("Cache-Control", "no-store")
	w.Header().Set("Pragma", "no-cache")
	writeJSON(w, http.StatusOK, TokenResponse{
		AccessToken: token,
		TokenType:   "Bearer",
		ExpiresIn:   int(ttl.Seconds()),
		Scope:       strings.Join(granted, " "),
	})
}
// ValidateToken looks token up in the in-memory store and returns its
// TokenInfo. It returns nil when the token is unknown or when the current
// time is past the token's ExpiresAt.
func (p *OAuthProvider) ValidateToken(token string) *TokenInfo {
	p.mu.RLock()
	defer p.mu.RUnlock()

	entry, known := p.tokens[token]
	if known && !time.Now().After(entry.ExpiresAt) {
		return entry
	}
	return nil
}
// RequireAuth returns middleware that forwards a request to the next handler
// only when it carries a valid bearer token holding every scope in scopes.
// A token that holds the "admin" scope satisfies any required scope. With no
// scopes given, any valid token is accepted.
//
// Rejections are written as a TokenError and carry a WWW-Authenticate
// challenge (RFC 6750, section 3):
//
//   - 401 unauthorized: no Authorization header
//   - 401 invalid_token: the header is not "Bearer <token>", or the token is
//     unknown or expired
//   - 403 insufficient_scope: the token lacks a required scope
//
// On success the token's TokenInfo is attached to the request context via
// withTokenInfo before the next handler runs.
func (p *OAuthProvider) RequireAuth(scopes ...string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			auth := r.Header.Get("Authorization")
			if auth == "" {
				// No credentials at all: a bare challenge without an error code.
				w.Header().Set("WWW-Authenticate", "Bearer")
				writeJSON(w, http.StatusUnauthorized, TokenError{Error: "unauthorized"})
				return
			}

			parts := strings.SplitN(auth, " ", 2)
			if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
				w.Header().Set("WWW-Authenticate", `Bearer error="invalid_token"`)
				writeJSON(w, http.StatusUnauthorized, TokenError{Error: "invalid_token"})
				return
			}

			info := p.ValidateToken(parts[1])
			if info == nil {
				w.Header().Set("WWW-Authenticate", `Bearer error="invalid_token"`)
				writeJSON(w, http.StatusUnauthorized, TokenError{Error: "invalid_token"})
				return
			}

			// Every required scope must be held; the first missing one is
			// reported to the caller.
		nextScope:
			for _, required := range scopes {
				for _, held := range info.Scopes {
					if held == required || held == "admin" {
						continue nextScope
					}
				}
				w.Header().Set("WWW-Authenticate",
					`Bearer error="insufficient_scope", scope="`+required+`"`)
				writeJSON(w, http.StatusForbidden, TokenError{
					Error:       "insufficient_scope",
					Description: "Required scope: " + required,
				})
				return
			}

			// Make the token metadata available to downstream handlers.
			ctx := withTokenInfo(r.Context(), info)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
// generateToken returns a new opaque access token: the unpadded, URL-safe
// base64 encoding of an HMAC-SHA256 digest computed over the string form of a
// freshly generated UUID.
//
// NOTE(review): the HMAC key is a fixed string compiled into the binary, so it
// adds no secrecy. The token is only as unpredictable as the value returned by
// uuid.New (presumably a random version-4 UUID; confirm against the
// google/uuid documentation).
func generateToken() string {
	mac := hmac.New(sha256.New, []byte("mc-token-salt"))
	mac.Write([]byte(uuid.New().String()))
	digest := mac.Sum(nil)
	return base64.RawURLEncoding.EncodeToString(digest)
}
// secureCompare reports whether a and b are equal, using a comparison whose
// running time does not depend on where the inputs differ.
//
// subtle.ConstantTimeCompare returns immediately when its arguments differ in
// length, which would let a caller probe the length of a secret by timing.
// Comparing the fixed-size SHA-256 digests of the inputs instead means the
// comparison step always examines the same number of bytes.
func secureCompare(a, b string) bool {
	digestA := sha256.Sum256([]byte(a))
	digestB := sha256.Sum256([]byte(b))
	return subtle.ConstantTimeCompare(digestA[:], digestB[:]) == 1
}
// writeJSON sends v as a JSON response body with the given HTTP status and a
// Content-Type of application/json. The body ends with a newline.
//
// v is encoded before anything is written, because the status line cannot be
// changed once sent. If v cannot be encoded, the response is a 500 with a
// generic server_error body instead of the requested status and an empty body.
func writeJSON(w http.ResponseWriter, status int, v any) {
	body, err := json.Marshal(v)
	if err != nil {
		status = http.StatusInternalServerError
		body = []byte(`{"error":"server_error"}`)
	}
	body = append(body, '\n')

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	// A failed write means the client went away; the response is already
	// committed, so there is nothing left to report.
	_, _ = w.Write(body)
}