clavitor/clavitor.ai/oauth.go

294 lines
8.6 KiB
Go

// Social login (Google, Apple, Meta).
//
// All three implement the OAuth 2.0 / OIDC authorization-code flow.
// Each provider needs three env vars:
//
// GOOGLE_CLIENT_ID, GOOGLE_CLIENT_SECRET, GOOGLE_REDIRECT_URL
// APPLE_CLIENT_ID, APPLE_CLIENT_SECRET, APPLE_REDIRECT_URL
// META_CLIENT_ID, META_CLIENT_SECRET, META_REDIRECT_URL
//
// If a provider's vars aren't set, its /auth/<provider>/start route returns
// a "not configured" message and the button on /signup is functionally inert
// (still rendered for layout consistency).
//
// All three providers, on success, drop the user into the same onboarding flow
// as email signup: a customer is created in the TLW, the session cookie is set,
// and the user is redirected to /onboarding/profile.
package main
import (
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/url"
"os"
"strings"
"time"
)
// oauthProvider describes one OAuth 2.0 / OIDC identity provider: our client
// credentials for it plus the provider endpoints the login flow talks to.
type oauthProvider struct {
	// Name is the provider's short identifier (e.g. "google").
	Name string

	// ClientID, ClientSecret and RedirectURL are the client credentials and
	// the callback URL sent to the provider; all three must be non-empty for
	// the provider to count as configured.
	ClientID, ClientSecret, RedirectURL string

	// AuthURL is the authorization endpoint the browser is redirected to;
	// TokenURL is the token endpoint where the code is exchanged.
	AuthURL, TokenURL string

	// Scope is the space-separated scope list sent on the authorization request.
	Scope string

	// UserinfoURL is optional: the endpoint queried for the email when it is
	// not taken from the id_token.
	UserinfoURL string

	// EmailFromIDT, when true, makes the callback decode the email from the
	// id_token instead of calling UserinfoURL.
	EmailFromIDT bool
}
// oauthProviders maps a provider name ("google", "apple", "meta") to its
// configuration. It is populated by initOAuth; the /auth/ handlers look
// providers up here using the name taken from the request path.
var oauthProviders map[string]*oauthProvider
func initOAuth() {
oauthProviders = map[string]*oauthProvider{
"google": {
Name: "google",
ClientID: os.Getenv("GOOGLE_CLIENT_ID"),
ClientSecret: os.Getenv("GOOGLE_CLIENT_SECRET"),
RedirectURL: os.Getenv("GOOGLE_REDIRECT_URL"),
AuthURL: "https://accounts.google.com/o/oauth2/v2/auth",
TokenURL: "https://oauth2.googleapis.com/token",
Scope: "openid email profile",
UserinfoURL: "https://openidconnect.googleapis.com/v1/userinfo",
},
"apple": {
Name: "apple",
ClientID: os.Getenv("APPLE_CLIENT_ID"),
ClientSecret: os.Getenv("APPLE_CLIENT_SECRET"),
RedirectURL: os.Getenv("APPLE_REDIRECT_URL"),
AuthURL: "https://appleid.apple.com/auth/authorize",
TokenURL: "https://appleid.apple.com/auth/token",
Scope: "name email",
EmailFromIDT: true,
},
"meta": {
Name: "meta",
ClientID: os.Getenv("META_CLIENT_ID"),
ClientSecret: os.Getenv("META_CLIENT_SECRET"),
RedirectURL: os.Getenv("META_REDIRECT_URL"),
AuthURL: "https://www.facebook.com/v18.0/dialog/oauth",
TokenURL: "https://graph.facebook.com/v18.0/oauth/access_token",
Scope: "email public_profile",
UserinfoURL: "https://graph.facebook.com/me?fields=email,name",
},
}
for name, p := range oauthProviders {
if p.configured() {
log.Printf("OAuth: %s configured", name)
}
}
}
// configured reports whether the provider has everything needed to run the
// flow: a client ID, a client secret and a redirect URL, all non-empty.
func (p *oauthProvider) configured() bool {
	for _, required := range [...]string{p.ClientID, p.ClientSecret, p.RedirectURL} {
		if required == "" {
			return false
		}
	}
	return true
}
// handleOAuthStart handles GET /auth/{provider}/start.
//
// It resolves the provider named in the path, stores a random CSRF state
// nonce in a short-lived (10 minute) cookie, and redirects the browser to the
// provider's consent screen. Unknown providers get 404; known but
// unconfigured providers get 503 with a hint to use email sign-up instead.
func handleOAuthStart(w http.ResponseWriter, r *http.Request) {
	name := strings.TrimPrefix(r.URL.Path, "/auth/")
	name = strings.TrimSuffix(name, "/start")
	p, ok := oauthProviders[name]
	if !ok {
		http.NotFound(w, r)
		return
	}
	if !p.configured() {
		http.Error(w, fmt.Sprintf("%s sign-in is not configured yet — try email", name), http.StatusServiceUnavailable)
		return
	}

	// Apple is asked to answer with response_mode=form_post, i.e. a
	// cross-site POST to our callback.
	formPost := name == "apple"

	// CSRF state — random nonce stored in a short-lived cookie that the
	// callback verifies before exchanging the code.
	state := randomState()
	stateCookie := &http.Cookie{
		Name:     "clv_oauth_state_" + name,
		Value:    state,
		Path:     "/",
		MaxAge:   600,
		HttpOnly: true,
		Secure:   !devMode,
		SameSite: http.SameSiteLaxMode,
	}
	if formPost {
		// Browsers do not attach SameSite=Lax cookies to cross-site POST
		// requests, so with Lax the callback would never see the state
		// cookie and every form_post sign-in would fail with "invalid
		// state". SameSite=None is required here, and browsers only accept
		// SameSite=None cookies that are also Secure.
		stateCookie.SameSite = http.SameSiteNoneMode
		stateCookie.Secure = true
	}
	http.SetCookie(w, stateCookie)

	q := url.Values{}
	q.Set("client_id", p.ClientID)
	q.Set("redirect_uri", p.RedirectURL)
	q.Set("response_type", "code")
	q.Set("scope", p.Scope)
	q.Set("state", state)
	if formPost {
		q.Set("response_mode", "form_post")
	}
	http.Redirect(w, r, p.AuthURL+"?"+q.Encode(), http.StatusFound)
}
// handleOAuthCallback handles GET (or POST, for Apple's form_post)
// /auth/{provider}/callback.
//
// Steps, in order:
//  1. resolve the provider from the path (404 if unknown, 503 if unconfigured);
//  2. read code and state from the form (POST) or the query string (GET);
//  3. compare state with the CSRF cookie, then expire that cookie so the
//     nonce is single-use whatever happens next;
//  4. exchange the code at the provider's token endpoint;
//  5. obtain the email from the id_token or from the userinfo endpoint;
//  6. find or create the customer for that email, set the onboarding cookie
//     and redirect to /onboarding/profile.
//
// Every failure is answered with a plain-text error; upstream failures are
// also logged with the provider name.
func handleOAuthCallback(w http.ResponseWriter, r *http.Request) {
	name := strings.TrimPrefix(r.URL.Path, "/auth/")
	name = strings.TrimSuffix(name, "/callback")
	p, ok := oauthProviders[name]
	if !ok {
		http.NotFound(w, r)
		return
	}
	if !p.configured() {
		http.Error(w, "provider not configured", http.StatusServiceUnavailable)
		return
	}

	// Apple uses form_post — read params from the form. Others use the query string.
	param := r.URL.Query().Get
	if r.Method == http.MethodPost {
		if err := r.ParseForm(); err != nil {
			log.Printf("oauth %s: malformed callback form: %v", name, err)
			http.Error(w, "malformed request", http.StatusBadRequest)
			return
		}
		param = r.FormValue
	}
	code, state := param("code"), param("state")
	if code == "" {
		// When the user declines or the provider fails, the callback carries
		// an "error" parameter instead of a code; record it for diagnosis.
		if perr := param("error"); perr != "" {
			log.Printf("oauth %s: provider returned error %q", name, perr)
		}
		http.Error(w, "missing code", http.StatusBadRequest)
		return
	}

	// Verify CSRF state cookie
	c, err := r.Cookie("clv_oauth_state_" + name)
	if err != nil || c.Value == "" || c.Value != state {
		http.Error(w, "invalid state", http.StatusBadRequest)
		return
	}
	// The nonce has served its purpose: expire it now so it cannot be
	// replayed even if a later step fails. Headers set here are still sent
	// with any error response written below.
	http.SetCookie(w, &http.Cookie{Name: "clv_oauth_state_" + name, Value: "", Path: "/", MaxAge: -1})

	// Exchange the code for a token
	tokenForm := url.Values{}
	tokenForm.Set("grant_type", "authorization_code")
	tokenForm.Set("code", code)
	tokenForm.Set("client_id", p.ClientID)
	tokenForm.Set("client_secret", p.ClientSecret)
	tokenForm.Set("redirect_uri", p.RedirectURL)
	req, err := http.NewRequest("POST", p.TokenURL, strings.NewReader(tokenForm.Encode()))
	if err != nil {
		http.Error(w, "token request failed", http.StatusInternalServerError)
		return
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	client := &http.Client{Timeout: 15 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		log.Printf("oauth %s token exchange failed: %v", name, err)
		http.Error(w, "token exchange failed", http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()

	// Token responses are small; cap the read so a misbehaving upstream
	// cannot make us buffer an arbitrarily large body.
	body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
	if err != nil {
		log.Printf("oauth %s: reading token response: %v", name, err)
		http.Error(w, "token exchange failed", http.StatusBadGateway)
		return
	}
	if resp.StatusCode >= 400 {
		log.Printf("oauth %s token endpoint %d: %s", name, resp.StatusCode, string(body))
		http.Error(w, "token exchange rejected", http.StatusBadGateway)
		return
	}
	var tokResp struct {
		AccessToken string `json:"access_token"`
		IDToken     string `json:"id_token"`
	}
	if err := json.Unmarshal(body, &tokResp); err != nil {
		log.Printf("oauth %s: decoding token response: %v", name, err)
		http.Error(w, "token exchange failed", http.StatusBadGateway)
		return
	}

	// Get the email
	var email string
	if p.EmailFromIDT && tokResp.IDToken != "" {
		email = emailFromIDToken(tokResp.IDToken)
	} else if p.UserinfoURL != "" && tokResp.AccessToken != "" {
		email = fetchEmailFromUserinfo(p.UserinfoURL, tokResp.AccessToken)
	}
	if email == "" {
		log.Printf("oauth %s: could not get email from token/userinfo", name)
		http.Error(w, "could not retrieve email from provider", http.StatusBadGateway)
		return
	}
	email = strings.ToLower(email)

	// TLW: create or fetch the customer for this email
	if corpDB == nil {
		http.Error(w, "onboarding offline", http.StatusServiceUnavailable)
		return
	}
	ourID, _, found := findCustomerByEmail(email)
	if !found {
		res := tlwCreateCustomer(map[string]interface{}{"email": email})
		if res.Err != nil {
			log.Printf("oauth %s: tlwCreateCustomer for %s: %v", name, email, res.Err)
			http.Error(w, "could not create account", http.StatusInternalServerError)
			return
		}
		ourID = res.OurID
	}

	setOnboardingCookie(w, &onboardingState{
		Email:      email,
		Provider:   name,
		CustomerID: ourID,
	})
	http.Redirect(w, r, "/onboarding/profile", http.StatusSeeOther)
}
// randomState returns a fresh CSRF state nonce: 24 bytes from crypto/rand,
// base64url-encoded without padding (32 characters).
//
// It panics if the random source fails. Continuing would hand out a nonce
// built from a zeroed buffer, which is predictable and defeats the CSRF
// check, so failing the request is the safe outcome.
func randomState() string {
	buf := make([]byte, 24)
	if _, err := rand.Read(buf); err != nil {
		panic("oauth: reading random state: " + err.Error())
	}
	return base64.RawURLEncoding.EncodeToString(buf)
}
// emailFromIDToken extracts the "email" claim from a JWT id_token WITHOUT
// verifying the token's signature.
//
// We trust the payload because the token came directly from the provider's
// token endpoint over TLS — there's no third party to forge it. (For paranoid
// production use, swap in a JWKS-verifying library.)
//
// It returns "" when the token does not have exactly three dot-separated
// segments, when the payload is not decodable base64 or not valid JSON, or
// when the provider explicitly marks the address as unverified
// (email_verified is false or "false"). A missing email_verified claim is
// accepted.
func emailFromIDToken(idToken string) string {
	parts := strings.Split(idToken, ".")
	if len(parts) != 3 {
		return ""
	}

	// JWTs are specified as unpadded base64url, but be lenient: strip any
	// padding and accept either alphabet, so padded base64url and unpadded
	// std base64 decode as well as the two forms handled previously.
	seg := strings.TrimRight(parts[1], "=")
	payload, err := base64.RawURLEncoding.DecodeString(seg)
	if err != nil {
		payload, err = base64.RawStdEncoding.DecodeString(seg)
		if err != nil {
			return ""
		}
	}

	var claims struct {
		Email string `json:"email"`
		// Providers encode this claim as a JSON boolean or as a string.
		EmailVerified interface{} `json:"email_verified"`
	}
	if err := json.Unmarshal(payload, &claims); err != nil {
		return ""
	}

	// The caller uses the address as a login identity, so an address the
	// provider itself reports as unverified must not be accepted.
	switch v := claims.EmailVerified.(type) {
	case bool:
		if !v {
			return ""
		}
	case string:
		if strings.EqualFold(v, "false") {
			return ""
		}
	}
	return claims.Email
}
// fetchEmailFromUserinfo GETs endpoint with accessToken as a Bearer token and
// returns the "email" field of the JSON response.
//
// It returns "" on any failure: the request cannot be built or sent, the
// status is not 2xx, the body cannot be read or is not valid JSON, or the
// provider explicitly reports the address as unverified (email_verified is
// false or "false"). A response without email_verified is accepted.
// Transport and status failures are logged; the access token never is.
func fetchEmailFromUserinfo(endpoint, accessToken string) string {
	req, err := http.NewRequest("GET", endpoint, nil)
	if err != nil {
		return ""
	}
	req.Header.Set("Authorization", "Bearer "+accessToken)

	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		log.Printf("oauth userinfo request failed: %v", err)
		return ""
	}
	defer resp.Body.Close()

	// An error response is not a userinfo document; don't try to read an
	// email out of it.
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		log.Printf("oauth userinfo endpoint returned status %d", resp.StatusCode)
		return ""
	}

	// Userinfo documents are small; cap the read so a misbehaving upstream
	// cannot make us buffer an arbitrarily large body.
	body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
	if err != nil {
		log.Printf("oauth userinfo read failed: %v", err)
		return ""
	}

	var info struct {
		Email string `json:"email"`
		// Providers encode this field as a JSON boolean or as a string.
		EmailVerified interface{} `json:"email_verified"`
	}
	if err := json.Unmarshal(body, &info); err != nil {
		log.Printf("oauth userinfo decode failed: %v", err)
		return ""
	}

	// The caller uses the address as a login identity, so an address the
	// provider itself reports as unverified must not be accepted.
	switch v := info.EmailVerified.(type) {
	case bool:
		if !v {
			return ""
		}
	case string:
		if strings.EqualFold(v, "false") {
			return ""
		}
	}
	return info.Email
}