241 lines
5.7 KiB
Go
241 lines
5.7 KiB
Go
package lib
|
|
|
|
// ============================================================================
|
|
// Auth Database (auth.db) - OAuth tokens and sessions
|
|
// ============================================================================
|
|
// Separate from medical data (inou.db). Volatile/ephemeral data.
|
|
// Schema documented in docs/schema-auth.sql
|
|
// ============================================================================
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"reflect"
|
|
"strings"
|
|
"time"
|
|
|
|
_ "github.com/mattn/go-sqlite3"
|
|
)
|
|
|
|
// authDB is the package-level handle to the auth database. It holds the zero
// value (nil) until AuthDBInit assigns it, and is used by every helper in this
// file.
var authDB *sql.DB
|
|
|
|
// AuthDBInit opens the auth database connection
|
|
func AuthDBInit(dbPath string) error {
|
|
var err error
|
|
authDB, err = sql.Open("sqlite3", dbPath)
|
|
return err
|
|
}
|
|
|
|
// AuthDBClose closes the auth database connection
|
|
func AuthDBClose() {
|
|
if authDB != nil {
|
|
authDB.Close()
|
|
}
|
|
}
|
|
|
|
// authSave inserts or updates a record in auth.db (simplified, no encryption)
|
|
func authSave(table string, v interface{}) error {
|
|
val := reflect.ValueOf(v)
|
|
if val.Kind() == reflect.Ptr {
|
|
val = val.Elem()
|
|
}
|
|
typ := val.Type()
|
|
|
|
var cols []string
|
|
var placeholders []string
|
|
var vals []interface{}
|
|
|
|
for i := 0; i < typ.NumField(); i++ {
|
|
field := typ.Field(i)
|
|
tag := field.Tag.Get("db")
|
|
if tag == "" || tag == "-" {
|
|
continue
|
|
}
|
|
parts := strings.Split(tag, ",")
|
|
colName := parts[0]
|
|
cols = append(cols, colName)
|
|
placeholders = append(placeholders, "?")
|
|
vals = append(vals, val.Field(i).Interface())
|
|
}
|
|
|
|
query := fmt.Sprintf(
|
|
"INSERT OR REPLACE INTO %s (%s) VALUES (%s)",
|
|
table,
|
|
strings.Join(cols, ", "),
|
|
strings.Join(placeholders, ", "),
|
|
)
|
|
|
|
_, err := authDB.Exec(query, vals...)
|
|
return err
|
|
}
|
|
|
|
// authLoad retrieves a single record by primary key from auth.db
|
|
func authLoad(table string, pk interface{}, dest interface{}) error {
|
|
val := reflect.ValueOf(dest)
|
|
if val.Kind() != reflect.Ptr {
|
|
return fmt.Errorf("dest must be a pointer")
|
|
}
|
|
val = val.Elem()
|
|
typ := val.Type()
|
|
|
|
var pkCol string
|
|
var cols []string
|
|
for i := 0; i < typ.NumField(); i++ {
|
|
field := typ.Field(i)
|
|
tag := field.Tag.Get("db")
|
|
if tag == "" || tag == "-" {
|
|
continue
|
|
}
|
|
parts := strings.Split(tag, ",")
|
|
colName := parts[0]
|
|
isPK := len(parts) > 1 && parts[1] == "pk"
|
|
if isPK {
|
|
pkCol = colName
|
|
}
|
|
cols = append(cols, colName)
|
|
}
|
|
|
|
query := fmt.Sprintf("SELECT %s FROM %s WHERE %s = ?", strings.Join(cols, ", "), table, pkCol)
|
|
row := authDB.QueryRow(query, pk)
|
|
|
|
// Build scan destinations
|
|
ptrs := make([]interface{}, len(cols))
|
|
colIdx := 0
|
|
for i := 0; i < typ.NumField(); i++ {
|
|
field := typ.Field(i)
|
|
tag := field.Tag.Get("db")
|
|
if tag == "" || tag == "-" {
|
|
continue
|
|
}
|
|
ptrs[colIdx] = val.Field(i).Addr().Interface()
|
|
colIdx++
|
|
}
|
|
|
|
return row.Scan(ptrs...)
|
|
}
|
|
|
|
// authQuery executes a SELECT and scans into a slice (for auth.db)
|
|
func authQuery(query string, args []interface{}, dest interface{}) error {
|
|
rows, err := authDB.Query(query, args...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
|
|
sliceVal := reflect.ValueOf(dest)
|
|
if sliceVal.Kind() != reflect.Ptr || sliceVal.Elem().Kind() != reflect.Slice {
|
|
return fmt.Errorf("dest must be pointer to slice")
|
|
}
|
|
sliceVal = sliceVal.Elem()
|
|
elemType := sliceVal.Type().Elem()
|
|
isPtr := elemType.Kind() == reflect.Ptr
|
|
if isPtr {
|
|
elemType = elemType.Elem()
|
|
}
|
|
|
|
cols, _ := rows.Columns()
|
|
|
|
for rows.Next() {
|
|
elem := reflect.New(elemType).Elem()
|
|
ptrs := make([]interface{}, len(cols))
|
|
|
|
for i, col := range cols {
|
|
for j := 0; j < elemType.NumField(); j++ {
|
|
field := elemType.Field(j)
|
|
tag := field.Tag.Get("db")
|
|
if tag == "" {
|
|
continue
|
|
}
|
|
parts := strings.Split(tag, ",")
|
|
if parts[0] == col {
|
|
ptrs[i] = elem.Field(j).Addr().Interface()
|
|
break
|
|
}
|
|
}
|
|
if ptrs[i] == nil {
|
|
var dummy interface{}
|
|
ptrs[i] = &dummy
|
|
}
|
|
}
|
|
|
|
if err := rows.Scan(ptrs...); err != nil {
|
|
return err
|
|
}
|
|
|
|
if isPtr {
|
|
sliceVal.Set(reflect.Append(sliceVal, elem.Addr()))
|
|
} else {
|
|
sliceVal.Set(reflect.Append(sliceVal, elem))
|
|
}
|
|
}
|
|
|
|
return rows.Err()
|
|
}
|
|
|
|
// authDelete removes a record by primary key from auth.db
|
|
func authDelete(table, pkCol string, pkVal interface{}) error {
|
|
query := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", table, pkCol)
|
|
_, err := authDB.Exec(query, pkVal)
|
|
return err
|
|
}
|
|
|
|
// ============================================================================
|
|
// Session Management
|
|
// ============================================================================
|
|
|
|
// Session is one row of the sessions table in auth.db. The db tags give the
// column names used by the reflection helpers in this file; the "pk" option
// marks the column authLoad looks rows up by.
type Session struct {
	// Token identifies the session and is the primary key.
	Token string `db:"token,pk"`
	// DossierID is the dossier the session belongs to.
	DossierID string `db:"dossier_id"`
	// CreatedAt is the creation time in Unix seconds (set by SessionCreate).
	CreatedAt int64 `db:"created_at"`
	// ExpiresAt is the expiry time in Unix seconds (set by SessionCreate).
	ExpiresAt int64 `db:"expires_at"`
}
|
|
|
|
func SessionCreate(token, dossierID string, maxAgeSeconds int) error {
|
|
now := time.Now().Unix()
|
|
s := &Session{
|
|
Token: token,
|
|
DossierID: dossierID,
|
|
CreatedAt: now,
|
|
ExpiresAt: now + int64(maxAgeSeconds),
|
|
}
|
|
return authSave("sessions", s)
|
|
}
|
|
|
|
func SessionDelete(token string) error {
|
|
return authDelete("sessions", "token", token)
|
|
}
|
|
|
|
// FIX TASK-018: Delete all sessions for a dossier (session rotation on login)
|
|
func SessionDeleteByDossier(dossierID string) error {
|
|
_, err := authDB.Exec("DELETE FROM sessions WHERE dossier_id = ?", dossierID)
|
|
return err
|
|
}
|
|
|
|
func SessionGetDossier(token string) string {
|
|
var s Session
|
|
if err := authLoad("sessions", token, &s); err != nil {
|
|
return ""
|
|
}
|
|
// Check expiry
|
|
if time.Now().Unix() > s.ExpiresAt {
|
|
SessionDelete(token)
|
|
return ""
|
|
}
|
|
return s.DossierID
|
|
}
|
|
|
|
func SessionCleanup() {
|
|
authDB.Exec("DELETE FROM sessions WHERE expires_at < ?", time.Now().Unix())
|
|
}
|
|
|
|
// DossierGetBySessionToken gets dossier by session token (for API auth)
|
|
func DossierGetBySessionToken(token string) *Dossier {
|
|
dossierID := SessionGetDossier(token)
|
|
if dossierID == "" {
|
|
return nil
|
|
}
|
|
d, _ := DossierGet(dossierID, dossierID)
|
|
return d
|
|
}
|