inou/lib/db_auth.go

241 lines
5.7 KiB
Go

package lib
// ============================================================================
// Auth Database (auth.db) - OAuth tokens and sessions
// ============================================================================
// Separate from medical data (inou.db). Volatile/ephemeral data.
// Schema documented in docs/schema-auth.sql
// ============================================================================
import (
"database/sql"
"fmt"
"reflect"
"strings"
"time"
_ "github.com/mattn/go-sqlite3"
)
// authDB is the package-level handle to auth.db. AuthDBInit assigns it and
// AuthDBClose closes it; every helper in this file goes through it, so
// AuthDBInit must be called before any of them.
var authDB *sql.DB
// AuthDBInit opens the auth database at dbPath and verifies that it is
// reachable before installing it as the package-wide handle.
//
// sql.Open only validates its arguments and does not establish a connection,
// so a Ping is issued here to surface problems (bad path, unreadable file)
// at startup instead of on the first query. On any failure the new handle is
// closed and the previously configured handle, if any, is left untouched.
func AuthDBInit(dbPath string) error {
	db, err := sql.Open("sqlite3", dbPath)
	if err != nil {
		return fmt.Errorf("opening auth db %q: %w", dbPath, err)
	}
	if err := db.Ping(); err != nil {
		_ = db.Close() // the handle is unusable; the Ping error is the one worth reporting
		return fmt.Errorf("connecting to auth db %q: %w", dbPath, err)
	}
	authDB = db
	return nil
}
// AuthDBClose closes the handle installed by AuthDBInit. It does nothing when
// authDB is nil. The error from Close is deliberately dropped because this
// function has no way to return it.
func AuthDBClose() {
	db := authDB
	if db == nil {
		return
	}
	_ = db.Close()
}
// authSave upserts v into table with "INSERT OR REPLACE". Values are written
// as-is (no encryption).
//
// v must be a struct or a non-nil pointer to one. Each field carrying a
// `db:"column[,options]"` tag becomes a column; untagged fields and fields
// tagged `db:"-"` are skipped. table is spliced into the SQL text, so it must
// be a trusted identifier; only the field values are bound as parameters.
//
// It returns an error for non-struct input, or for a struct with no
// db-tagged fields (which would otherwise produce invalid SQL); otherwise it
// returns the error from Exec.
func authSave(table string, v interface{}) error {
	val := reflect.ValueOf(v)
	if val.Kind() == reflect.Ptr {
		if val.IsNil() {
			return fmt.Errorf("saving to %s: nil %s", table, val.Type())
		}
		val = val.Elem()
	}
	// Also catches a nil interface, whose Kind is reflect.Invalid.
	if val.Kind() != reflect.Struct {
		return fmt.Errorf("saving to %s: expected a struct, got %s", table, val.Kind())
	}
	typ := val.Type()

	n := typ.NumField()
	cols := make([]string, 0, n)
	placeholders := make([]string, 0, n)
	vals := make([]interface{}, 0, n)
	for i := 0; i < n; i++ {
		tag := typ.Field(i).Tag.Get("db")
		if tag == "" || tag == "-" {
			continue
		}
		cols = append(cols, strings.Split(tag, ",")[0])
		placeholders = append(placeholders, "?")
		vals = append(vals, val.Field(i).Interface())
	}
	if len(cols) == 0 {
		return fmt.Errorf("saving to %s: %s has no db-tagged fields", table, typ)
	}

	query := fmt.Sprintf(
		"INSERT OR REPLACE INTO %s (%s) VALUES (%s)",
		table,
		strings.Join(cols, ", "),
		strings.Join(placeholders, ", "),
	)
	_, err := authDB.Exec(query, vals...)
	return err
}
// authLoad fetches the row of table whose primary-key column equals pk and
// scans it into dest.
//
// dest must be a non-nil pointer to a struct. Fields tagged
// `db:"column[,options]"` are selected in declaration order; untagged and
// `db:"-"` fields are ignored. The WHERE column is taken from the field whose
// first tag option is "pk" (`db:"column,pk"`); if several fields carry it, the
// last one wins. table is spliced into the SQL text, so it must be a trusted
// identifier.
//
// The error from Row.Scan is returned unwrapped, so callers can still test
// for sql.ErrNoRows.
func authLoad(table string, pk interface{}, dest interface{}) error {
	val := reflect.ValueOf(dest)
	if val.Kind() != reflect.Ptr {
		return fmt.Errorf("dest must be a pointer")
	}
	if val.IsNil() {
		return fmt.Errorf("dest must be a non-nil pointer")
	}
	val = val.Elem()
	if val.Kind() != reflect.Struct {
		return fmt.Errorf("dest must point to a struct, got %s", val.Kind())
	}
	typ := val.Type()

	// One pass collects both the column list and the matching scan targets,
	// so the two can never drift out of step.
	var pkCol string
	cols := make([]string, 0, typ.NumField())
	ptrs := make([]interface{}, 0, typ.NumField())
	for i := 0; i < typ.NumField(); i++ {
		tag := typ.Field(i).Tag.Get("db")
		if tag == "" || tag == "-" {
			continue
		}
		parts := strings.Split(tag, ",")
		colName := parts[0]
		if len(parts) > 1 && parts[1] == "pk" {
			pkCol = colName
		}
		cols = append(cols, colName)
		ptrs = append(ptrs, val.Field(i).Addr().Interface())
	}
	if pkCol == "" {
		// Without this the query would read "WHERE  = ?" and fail in SQLite.
		return fmt.Errorf("loading from %s: %s has no field tagged pk", table, typ)
	}

	query := fmt.Sprintf("SELECT %s FROM %s WHERE %s = ?", strings.Join(cols, ", "), table, pkCol)
	return authDB.QueryRow(query, pk).Scan(ptrs...)
}
// authQuery runs query with args against auth.db and appends one element to
// the slice pointed to by dest for each result row.
//
// dest must be a pointer to a slice of structs or of struct pointers. Result
// columns are matched by name against the fields' `db:"column[,options]"`
// tags (untagged and `db:"-"` fields never match). Columns with no matching
// field are read and discarded; fields with no matching column keep their
// zero value. Existing slice contents are kept: rows are appended.
func authQuery(query string, args []interface{}, dest interface{}) error {
	// Validate dest before running the query, so misuse fails fast without
	// touching the database.
	sliceVal := reflect.ValueOf(dest)
	if sliceVal.Kind() != reflect.Ptr || sliceVal.Elem().Kind() != reflect.Slice {
		return fmt.Errorf("dest must be pointer to slice")
	}
	sliceVal = sliceVal.Elem()
	elemType := sliceVal.Type().Elem()
	isPtr := elemType.Kind() == reflect.Ptr
	if isPtr {
		elemType = elemType.Elem()
	}
	if elemType.Kind() != reflect.Struct {
		return fmt.Errorf("dest must be pointer to slice of structs, got %s", sliceVal.Type())
	}

	// Resolve column name -> field index once, instead of rescanning the
	// struct tags for every column of every row. The first field that claims
	// a column wins.
	fieldByCol := make(map[string]int, elemType.NumField())
	for j := 0; j < elemType.NumField(); j++ {
		tag := elemType.Field(j).Tag.Get("db")
		if tag == "" || tag == "-" {
			continue
		}
		name := strings.SplitN(tag, ",", 2)[0]
		if _, taken := fieldByCol[name]; !taken {
			fieldByCol[name] = j
		}
	}

	rows, err := authDB.Query(query, args...)
	if err != nil {
		return err
	}
	defer rows.Close()

	cols, err := rows.Columns()
	if err != nil {
		return err
	}

	for rows.Next() {
		elem := reflect.New(elemType).Elem()
		ptrs := make([]interface{}, len(cols))
		for i, col := range cols {
			if j, ok := fieldByCol[col]; ok {
				ptrs[i] = elem.Field(j).Addr().Interface()
				continue
			}
			// No matching field: Scan still needs one destination per column.
			var discard interface{}
			ptrs[i] = &discard
		}
		if err := rows.Scan(ptrs...); err != nil {
			return err
		}
		if isPtr {
			sliceVal.Set(reflect.Append(sliceVal, elem.Addr()))
		} else {
			sliceVal.Set(reflect.Append(sliceVal, elem))
		}
	}
	return rows.Err()
}
// authDelete removes a record by primary key from auth.db
func authDelete(table, pkCol string, pkVal interface{}) error {
query := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", table, pkCol)
_, err := authDB.Exec(query, pkVal)
return err
}
// ============================================================================
// Session Management
// ============================================================================
// Session is one row of the sessions table in auth.db: it maps a token to the
// dossier that token authenticates. Fields bind to columns through their db
// tags; the ",pk" option marks Token as the key that authLoad looks up by.
type Session struct {
	Token     string `db:"token,pk"`
	DossierID string `db:"dossier_id"`
	CreatedAt int64  `db:"created_at"` // Unix seconds (time.Now().Unix())
	ExpiresAt int64  `db:"expires_at"` // Unix seconds; the session is expired once now > ExpiresAt
}
func SessionCreate(token, dossierID string, maxAgeSeconds int) error {
now := time.Now().Unix()
s := &Session{
Token: token,
DossierID: dossierID,
CreatedAt: now,
ExpiresAt: now + int64(maxAgeSeconds),
}
return authSave("sessions", s)
}
// SessionDelete removes the session identified by token. Deleting a token
// that has no session is not an error.
func SessionDelete(token string) error {
	const table, keyColumn = "sessions", "token"
	return authDelete(table, keyColumn, token)
}
// SessionDeleteByDossier removes every session belonging to dossierID.
// It is intended for session rotation on login (FIX TASK-018): prior sessions
// for the dossier are invalidated.
func SessionDeleteByDossier(dossierID string) error {
	const stmt = "DELETE FROM sessions WHERE dossier_id = ?"
	if _, err := authDB.Exec(stmt, dossierID); err != nil {
		return err
	}
	return nil
}
// SessionGetDossier resolves a session token to its dossier ID.
//
// It returns "" when the lookup fails for any reason (including an unknown
// token) or when the session has expired. An expired session is also deleted.
// The delete is best effort: its error is ignored.
func SessionGetDossier(token string) string {
	var sess Session
	if authLoad("sessions", token, &sess) != nil {
		return ""
	}
	if now := time.Now().Unix(); now <= sess.ExpiresAt {
		return sess.DossierID
	}
	// Expired: purge it eagerly so it does not linger until the next cleanup.
	_ = SessionDelete(token)
	return ""
}
// SessionCleanup deletes every session whose expiry lies strictly in the past.
// A session expiring in the current second is kept, which matches the
// now > ExpiresAt expiry test in SessionGetDossier. Both the result and any
// error from the delete are discarded.
func SessionCleanup() {
	cutoff := time.Now().Unix()
	_, _ = authDB.Exec("DELETE FROM sessions WHERE expires_at < ?", cutoff)
}
// DossierGetBySessionToken returns the dossier that owns the session token,
// for API authentication. It returns nil when the token is unknown or
// expired, or when the dossier cannot be loaded.
func DossierGetBySessionToken(token string) *Dossier {
	dossierID := SessionGetDossier(token)
	if dossierID == "" {
		return nil
	}
	// NOTE(review): both arguments are the session's dossier ID, presumably
	// the accessor and the target (self-access); confirm against DossierGet.
	d, err := DossierGet(dossierID, dossierID)
	if err != nil {
		// The value returned alongside an error is not trustworthy, and a
		// non-nil result here is used for API auth, so fail closed.
		return nil
	}
	return d
}