inou/lib/db_queries.go

942 lines
24 KiB
Go

package lib
// ============================================================================
// ⛔ CRITICAL: DO NOT MODIFY THIS FILE WITHOUT JOHAN'S EXPRESS CONSENT
// ============================================================================
// This is the ONLY file allowed to access the database directly.
// All other code must use these functions: Save, Load, Query, Delete, Count
//
// Run `make check-db` to verify no direct DB access exists elsewhere.
// ============================================================================
import (
	"crypto/rand"
	"crypto/sha256"
	"crypto/subtle"
	"database/sql"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"reflect"
	"strings"
	"time"
)
// Store provides a generic data layer with automatic encryption.
// String and []byte fields are encrypted automatically, except string
// columns whose name ends in "_id", which are stored as plain text.
// Int/int64/bool fields pass through unchanged.
//
// Struct tags:
// db:"column_name" - maps field to column (default: lowercase field name)
// db:"column_name,pk" - marks as primary key
// db:"-" - skip this field
//
// Example:
// type Entry struct {
// EntryID int64 `db:"entry_id,pk"`
// Type string `db:"type"` // auto-encrypted
// Status int `db:"status"` // not encrypted
// }
// fieldInfo describes how one struct field maps onto a database column.
// Instances are produced by getTableInfo from the field's reflection data
// and its `db` struct tag.
type fieldInfo struct {
Name string // struct field name
Column string // db column name (from the tag, or the lowercased field name when the tag gives none)
Type reflect.Type // Go type
IsPK bool // is primary key (the tag carries the ",pk" option)
Skip bool // skip this field; NOTE(review): getTableInfo omits `db:"-"` fields instead of setting this
Index int // field index in struct, as used with reflect's Field(i)
}
// tableInfo holds metadata about a struct/table mapping, as built by
// getTableInfo.
type tableInfo struct {
Name string // table name as passed to getTableInfo ("" when mapping ad-hoc query results)
Fields []fieldInfo // mapped fields in struct declaration order; `db:"-"` fields are not included
PK *fieldInfo // the entry of Fields tagged ",pk"; nil when no field carries that option
}
// getTableInfo extracts table metadata from a struct using reflection.
//
// v may be a struct or a pointer to a struct; anything else (including an
// untyped nil) yields an error instead of a panic. Every field becomes a
// column unless it is tagged `db:"-"`. The column name defaults to the
// lowercased field name and can be overridden by the first tag component;
// the ",pk" option marks the primary key (if several fields carry it, the
// last one wins).
//
// A primary key is required whenever table is non-empty; table == "" is used
// for mapping arbitrary query results, where no key is needed.
func getTableInfo(table string, v any) (*tableInfo, error) {
	t := reflect.TypeOf(v)
	if t == nil {
		// reflect.TypeOf(nil) is nil; calling Kind on it would panic.
		return nil, fmt.Errorf("expected struct, got nil")
	}
	if t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	if t.Kind() != reflect.Struct {
		return nil, fmt.Errorf("expected struct, got %s", t.Kind())
	}

	info := &tableInfo{
		Name:   table,
		Fields: make([]fieldInfo, 0, t.NumField()),
	}
	pkPos := -1 // position of the primary key inside info.Fields
	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)

		tag := field.Tag.Get("db")
		if tag == "-" {
			continue
		}

		fi := fieldInfo{
			Name:   field.Name,
			Column: strings.ToLower(field.Name), // default: lowercase field name
			Type:   field.Type,
			Index:  i,
		}

		// Tag options override the default column name and may flag the PK.
		if tag != "" {
			parts := strings.Split(tag, ",")
			if parts[0] != "" {
				fi.Column = parts[0]
			}
			for _, opt := range parts[1:] {
				if opt == "pk" {
					fi.IsPK = true
				}
			}
		}

		if fi.IsPK {
			pkPos = len(info.Fields)
		}
		info.Fields = append(info.Fields, fi)
	}

	// Take the PK pointer only once the slice has stopped growing, so it can
	// never refer to a backing array that append has since replaced.
	if pkPos >= 0 {
		info.PK = &info.Fields[pkPos]
	}

	// PK required for table operations, optional for QuerySQL (table="")
	if table != "" && info.PK == nil {
		return nil, fmt.Errorf("no primary key defined for table %s (use `db:\"col,pk\"`)", table)
	}
	return info, nil
}
// goTypeToSQLite returns the SQLite column type used for a Go field type:
// INTEGER for bools and all sized signed/unsigned integers, REAL for floats,
// BLOB for []byte, and TEXT for strings, other slices and every other kind.
func goTypeToSQLite(t reflect.Type) string {
	kind := t.Kind()

	if kind == reflect.Float32 || kind == reflect.Float64 {
		return "REAL"
	}
	if kind == reflect.Slice && t.Elem().Kind() == reflect.Uint8 {
		return "BLOB"
	}

	switch kind {
	case reflect.Bool,
		reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return "INTEGER"
	}

	// Strings, non-byte slices and anything unrecognised are stored as text.
	return "TEXT"
}
// Verify checks that the database schema matches the struct definition.
//
// It reads the table's columns via PRAGMA table_info and compares them with
// the columns derived from v's `db` tags. It returns nil when every mapped
// column exists. Otherwise the error text contains a suggested CREATE TABLE
// statement (table missing) or ALTER TABLE statements (columns missing).
// Column types and extra database columns are not compared.
//
// table is interpolated into the PRAGMA statement, so it must be a trusted
// identifier, never user input.
func Verify(table string, v any) error {
	info, err := getTableInfo(table, v)
	if err != nil {
		return err
	}

	// Get current columns from DB
	rows, err := db.Query(fmt.Sprintf("PRAGMA table_info(%s)", table))
	if err != nil {
		return fmt.Errorf("failed to query table info: %w", err)
	}
	defer rows.Close()

	dbColumns := make(map[string]string) // column -> type
	for rows.Next() {
		var cid int
		var name, colType string
		var notNull, pk int
		var dflt sql.NullString
		if err := rows.Scan(&cid, &name, &colType, &notNull, &dflt, &pk); err != nil {
			return err
		}
		dbColumns[name] = colType
	}
	// Without this check an iteration failure would leave dbColumns empty or
	// partial and be misreported below as a missing table or missing columns.
	if err := rows.Err(); err != nil {
		return fmt.Errorf("failed to read table info: %w", err)
	}

	// PRAGMA table_info returns no rows for an unknown table.
	if len(dbColumns) == 0 {
		// Generate CREATE TABLE
		var cols []string
		for _, f := range info.Fields {
			colDef := fmt.Sprintf("%s %s", f.Column, goTypeToSQLite(f.Type))
			if f.IsPK {
				colDef += " PRIMARY KEY"
			}
			cols = append(cols, colDef)
		}
		return fmt.Errorf(`schema mismatch: table '%s' does not exist
Suggested fix:
CREATE TABLE %s (
%s
);
`, table, table, strings.Join(cols, ",\n "))
	}

	// Check for missing columns
	var missing []fieldInfo
	for _, f := range info.Fields {
		if _, exists := dbColumns[f.Column]; !exists {
			missing = append(missing, f)
		}
	}
	if len(missing) > 0 {
		var alters []string
		var details []string
		for _, f := range missing {
			sqlType := goTypeToSQLite(f.Type)
			details = append(details, fmt.Sprintf(" - %s (%s)", f.Column, sqlType))
			alters = append(alters, fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s;", table, f.Column, sqlType))
		}
		return fmt.Errorf(`schema mismatch for table '%s'
Missing columns:
%s
Suggested fix:
%s
Run these manually, then restart.
`, table, strings.Join(details, "\n"), strings.Join(alters, "\n "))
	}
	return nil
}
// VerifyAll runs Verify over a flat list of alternating arguments:
// table name (string), struct, table name, struct, ...
// It stops at and returns the first failure. An odd argument count or a
// non-string in a table-name position is reported as an error.
func VerifyAll(pairs ...any) error {
	if len(pairs)%2 != 0 {
		return fmt.Errorf("VerifyAll requires pairs of (table string, struct)")
	}

	for pos := 0; pos+1 < len(pairs); pos += 2 {
		name, isString := pairs[pos].(string)
		if !isString {
			return fmt.Errorf("expected string for table name at position %d", pos)
		}
		err := Verify(name, pairs[pos+1])
		if err != nil {
			return err
		}
	}
	return nil
}
// Save upserts struct(s) to the database.
// Accepts a single struct or a slice of structs.
// String and []byte fields are encrypted automatically.
// Slices are wrapped in a transaction for atomicity.
func Save(table string, v any) error {
start := time.Now()
defer func() { logSlowQuery("INSERT OR REPLACE INTO "+table, time.Since(start)) }()
val := reflect.ValueOf(v)
if val.Kind() == reflect.Ptr {
val = val.Elem()
}
// Slice: bulk upsert with transaction
if val.Kind() == reflect.Slice {
if val.Len() == 0 {
return nil
}
// Get table info from first element
elemType := val.Type().Elem()
if elemType.Kind() == reflect.Ptr {
elemType = elemType.Elem()
}
sample := reflect.New(elemType).Interface()
info, err := getTableInfo(table, sample)
if err != nil {
return err
}
// Build query once
var columns []string
var placeholders []string
for _, f := range info.Fields {
columns = append(columns, f.Column)
placeholders = append(placeholders, "?")
}
query := fmt.Sprintf(
"INSERT OR REPLACE INTO %s (%s) VALUES (%s)",
table,
strings.Join(columns, ", "),
strings.Join(placeholders, ", "),
)
// Transaction with prepared statement
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
stmt, err := tx.Prepare(query)
if err != nil {
return err
}
defer stmt.Close()
for i := 0; i < val.Len(); i++ {
elem := val.Index(i)
if elem.Kind() == reflect.Ptr {
elem = elem.Elem()
}
values := make([]any, len(info.Fields))
for j, f := range info.Fields {
values[j] = encryptField(elem.Field(f.Index), f.Column)
}
if _, err := stmt.Exec(values...); err != nil {
return err
}
}
return tx.Commit()
}
// Single struct
info, err := getTableInfo(table, v)
if err != nil {
return err
}
var columns []string
var placeholders []string
var values []any
for _, f := range info.Fields {
columns = append(columns, f.Column)
placeholders = append(placeholders, "?")
values = append(values, encryptField(val.Field(f.Index), f.Column))
}
query := fmt.Sprintf(
"INSERT OR REPLACE INTO %s (%s) VALUES (%s)",
table,
strings.Join(columns, ", "),
strings.Join(placeholders, ", "),
)
_, err = db.Exec(query, values...)
return err
}
// Load retrieves a record by primary key and populates the struct.
//
// v must be a non-nil pointer to a struct whose `db` tags describe the
// table; one field must be tagged ",pk". String and []byte fields are
// decrypted via decryptAndSet. When no row matches, the error from
// (*sql.Row).Scan is returned unchanged.
func Load(table string, id string, v any) error {
	start := time.Now()
	defer func() { logSlowQuery("SELECT FROM "+table+" WHERE pk=?", time.Since(start), id) }()

	// Validate the destination first: a nil value or nil pointer would
	// otherwise panic inside reflection further down.
	val := reflect.ValueOf(v)
	if val.Kind() != reflect.Ptr || val.IsNil() {
		return fmt.Errorf("Load requires a pointer to struct")
	}

	info, err := getTableInfo(table, v)
	if err != nil {
		return err
	}
	if info.PK == nil {
		// getTableInfo only enforces a primary key when table is non-empty.
		return fmt.Errorf("Load requires a primary key")
	}
	target := val.Elem()

	// Build the column list and matching scan destinations in one pass.
	columns := make([]string, len(info.Fields))
	scanDest := make([]any, len(info.Fields))
	for i, f := range info.Fields {
		columns[i] = f.Column
		scanDest[i] = createScanDest(f.Type)
	}

	query := fmt.Sprintf(
		"SELECT %s FROM %s WHERE %s = ?",
		strings.Join(columns, ", "),
		table,
		info.PK.Column,
	)
	if err := db.QueryRow(query, id).Scan(scanDest...); err != nil {
		return err
	}

	// Decrypt and set fields
	for i, f := range info.Fields {
		decryptAndSet(target.Field(f.Index), scanDest[i], f.Type, f.Column)
	}
	return nil
}
// Query runs a SQL query and populates the slice.
//
// slicePtr must point to a slice of structs or struct pointers. Result
// columns are matched to struct fields by `db` tag (or lowercased field
// name); columns with no matching field are read and discarded. String and
// []byte fields are decrypted via decryptAndSet.
//
// The slice is only replaced when the whole result set was read
// successfully; an error during iteration is returned and leaves the
// destination untouched.
func Query(query string, args []any, slicePtr any) error {
	start := time.Now()
	defer func() { logSlowQuery(query, time.Since(start), args...) }()

	sliceVal := reflect.ValueOf(slicePtr)
	if sliceVal.Kind() != reflect.Ptr || sliceVal.Elem().Kind() != reflect.Slice {
		return fmt.Errorf("Query requires a pointer to slice")
	}
	sliceType := sliceVal.Elem().Type()
	elemType := sliceType.Elem()
	wantPtr := elemType.Kind() == reflect.Ptr
	if wantPtr {
		elemType = elemType.Elem()
	}

	// Get struct field info (table "" means no primary key is required).
	info, err := getTableInfo("", reflect.New(elemType).Interface())
	if err != nil {
		return err
	}

	// Build column->field mapping
	colToField := make(map[string]*fieldInfo, len(info.Fields))
	for i := range info.Fields {
		colToField[info.Fields[i].Column] = &info.Fields[i]
	}

	rows, err := db.Query(query, args...)
	if err != nil {
		return err
	}
	defer rows.Close()

	// Get column names from result
	cols, err := rows.Columns()
	if err != nil {
		return err
	}

	// Resolve each result column to its field once; nil marks a column that
	// has no matching struct field.
	fieldMap := make([]*fieldInfo, len(cols))
	for i, col := range cols {
		fieldMap[i] = colToField[col]
	}

	result := reflect.MakeSlice(sliceType, 0, 0)
	for rows.Next() {
		elem := reflect.New(elemType).Elem()

		// Fresh destinations per row so values never leak between rows.
		scanDest := make([]any, len(cols))
		for i, fi := range fieldMap {
			if fi != nil {
				scanDest[i] = createScanDest(fi.Type)
			} else {
				// Unknown column - scan into throwaway
				scanDest[i] = new(any)
			}
		}
		if err := rows.Scan(scanDest...); err != nil {
			return err
		}

		// Decrypt and set matched fields
		for i, fi := range fieldMap {
			if fi != nil {
				decryptAndSet(elem.Field(fi.Index), scanDest[i], fi.Type, fi.Column)
			}
		}

		if wantPtr {
			result = reflect.Append(result, elem.Addr())
		} else {
			result = reflect.Append(result, elem)
		}
	}
	// rows.Next returns false both at end of data and on failure; only
	// rows.Err distinguishes the two.
	if err := rows.Err(); err != nil {
		return err
	}

	sliceVal.Elem().Set(result)
	return nil
}
// Count executes a query that yields a single integer column (typically
// SELECT COUNT(*)) and returns that value together with any scan error.
// Example: Count("SELECT COUNT(*) FROM entries WHERE dossier_id = ? AND category = ?", dossierID, category)
func Count(query string, args ...any) (int, error) {
	var n int
	row := db.QueryRow(query, args...)
	scanErr := row.Scan(&n)
	return n, scanErr
}
// Delete removes a record by primary key.
// pkCol is the primary key column name, id is 16-char hex string.
func Delete(table, pkCol, id string) error {
query := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", table, pkCol)
_, err := db.Exec(query, id)
return err
}
// DeleteTree removes a record and all its descendants.
//
// Starting from id, it walks the parent-child hierarchy (rows whose
// parentCol equals an already collected pkCol value), then deletes the
// collected rows inside one transaction, descendants before ancestors.
// Works with any SQL database (no CTEs or CASCADE needed).
//
// The walk is iterative and fully drains and closes each child query before
// issuing the next one, so it never holds more than one cursor open. Each ID
// is visited once, so a cycle in the parent links cannot loop forever.
// table, pkCol and parentCol are spliced into the SQL text and must be
// trusted identifiers.
func DeleteTree(table, pkCol, parentCol, id string) error {
	selectChildren := fmt.Sprintf("SELECT %s FROM %s WHERE %s = ?", pkCol, table, parentCol)
	deleteRow := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", table, pkCol)

	// childrenOf returns the direct children of pid. Rows are closed before
	// returning so no connection stays pinned while the walk continues.
	childrenOf := func(pid string) ([]string, error) {
		rows, err := db.Query(selectChildren, pid)
		if err != nil {
			return nil, err
		}
		defer rows.Close()

		var kids []string
		for rows.Next() {
			var childID string
			if err := rows.Scan(&childID); err != nil {
				return nil, err
			}
			kids = append(kids, childID)
		}
		return kids, rows.Err()
	}

	// Collect all IDs (parent + descendants) breadth-first. Every node is
	// appended after its parent, so walking the list backwards deletes
	// children first.
	ids := []string{id}
	seen := map[string]bool{id: true}
	for i := 0; i < len(ids); i++ {
		kids, err := childrenOf(ids[i])
		if err != nil {
			return err
		}
		for _, kid := range kids {
			if seen[kid] {
				continue
			}
			seen[kid] = true
			ids = append(ids, kid)
		}
	}

	// Delete in reverse order (children first)
	tx, err := db.Begin()
	if err != nil {
		return err
	}
	defer tx.Rollback() // no-op once Commit has succeeded

	for i := len(ids) - 1; i >= 0; i-- {
		if _, err := tx.Exec(deleteRow, ids[i]); err != nil {
			return err
		}
	}
	return tx.Commit()
}
// encryptField converts a struct field into the value bound to the SQL
// statement.
//
//   - string: empty strings and columns whose name ends in "_id" are stored
//     as-is; everything else goes through CryptoEncrypt.
//   - []byte: empty input becomes an empty byte slice; anything else goes
//     through CryptoEncryptBytes.
//   - all other kinds, including non-byte slices, are passed through
//     unchanged (no JSON encoding or encryption happens here).
//
// NOTE(review): decryptAndSet also treats "language", "timezone",
// "weight_unit" and "height_unit" as plain text, but this function has no
// such exemption — confirm those columns are never written through here.
func encryptField(v reflect.Value, column string) any {
	kind := v.Kind()

	if kind == reflect.String {
		plain := v.String()
		if plain == "" || strings.HasSuffix(column, "_id") {
			return plain
		}
		return CryptoEncrypt(plain)
	}

	if kind == reflect.Slice && v.Type().Elem().Kind() == reflect.Uint8 {
		raw := v.Bytes()
		if len(raw) == 0 {
			return []byte{}
		}
		return CryptoEncryptBytes(raw)
	}

	return v.Interface()
}
// createScanDest creates an appropriate scan destination for a Go type.
//
// Integers of every size, unsigned integers and bools scan into
// *sql.NullInt64; floats into *sql.NullFloat64; []byte into *[]byte;
// strings and everything else into *sql.NullString. The destination kinds
// must stay in step with decryptAndSet, which type-asserts on them.
func createScanDest(t reflect.Type) any {
	switch t.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
		reflect.Bool:
		return new(sql.NullInt64)
	case reflect.Float32, reflect.Float64:
		return new(sql.NullFloat64)
	case reflect.Slice:
		if t.Elem().Kind() == reflect.Uint8 {
			return new([]byte)
		}
		return new(sql.NullString)
	default:
		return new(sql.NullString)
	}
}

// decryptAndSet copies a scanned column value into the struct field,
// decrypting strings and []byte on the way.
//
// scanned must be the destination createScanDest produced for t. NULL
// columns leave the field at its zero value. String columns whose name ends
// in "_id", plus "language", "timezone", "weight_unit" and "height_unit",
// are copied without decryption. Numeric values that do not fit the field's
// type are skipped rather than truncated. A []byte value that fails to
// decrypt leaves the field empty, because this function has no error return.
// Kinds not listed below (for example non-byte slices) are not populated.
//
// NOTE(review): encryptField exempts only "_id" columns from encryption, not
// the four named plain-text columns — confirm how those columns are written.
func decryptAndSet(field reflect.Value, scanned any, t reflect.Type, column string) {
	switch t.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		if ns, ok := scanned.(*sql.NullInt64); ok && ns.Valid && !field.OverflowInt(ns.Int64) {
			field.SetInt(ns.Int64)
		}

	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		if ns, ok := scanned.(*sql.NullInt64); ok && ns.Valid && ns.Int64 >= 0 {
			if u := uint64(ns.Int64); !field.OverflowUint(u) {
				field.SetUint(u)
			}
		}

	case reflect.Float32, reflect.Float64:
		if nf, ok := scanned.(*sql.NullFloat64); ok && nf.Valid {
			field.SetFloat(nf.Float64)
		}

	case reflect.String:
		ns, ok := scanned.(*sql.NullString)
		if !ok || !ns.Valid {
			return
		}
		// Don't decrypt ID columns or known plain-text columns
		plain := strings.HasSuffix(column, "_id")
		switch column {
		case "language", "timezone", "weight_unit", "height_unit":
			plain = true
		}
		if plain {
			field.SetString(ns.String)
		} else {
			field.SetString(CryptoDecrypt(ns.String))
		}

	case reflect.Slice:
		if t.Elem().Kind() == reflect.Uint8 {
			// []byte - decrypt
			if b, ok := scanned.(*[]byte); ok && b != nil && len(*b) > 0 {
				decrypted, err := CryptoDecryptBytes(*b)
				if err == nil {
					field.SetBytes(decrypted)
				}
			}
		}

	case reflect.Bool:
		if ns, ok := scanned.(*sql.NullInt64); ok && ns.Valid {
			field.SetBool(ns.Int64 != 0)
		}
	}
}
// =============================================================================
// OAuth 2.0 Store Functions
// =============================================================================
// OAuthClient represents a registered OAuth client (Claude, Flutter, etc.)
// Rows live in the oauth_clients table, keyed by client_id.
type OAuthClient struct {
ClientID string `db:"client_id,pk"`
ClientSecret string `db:"client_secret"` // hashed with hashSecret; empty for public clients
Name string `db:"name"`
RedirectURIs []string `db:"-"` // parsed from JSON (RedirectJSON) by OAuthClientGet; not a column
RedirectJSON string `db:"redirect_uris"` // JSON array form of RedirectURIs as stored
CreatedAt int64 `db:"created_at"` // Unix seconds at creation
}
// OAuthCode represents a temporary authorization code
// Rows live in the oauth_codes table, keyed by code.
type OAuthCode struct {
Code string `db:"code,pk"`
ClientID string `db:"client_id"`
DossierID string `db:"dossier_id"`
RedirectURI string `db:"redirect_uri"`
CodeChallenge string `db:"code_challenge"` // PKCE challenge; empty means OAuthCodeVerifyPKCE accepts any verifier
CodeChallengeMethod string `db:"code_challenge_method"` // only "S256" is accepted by OAuthCodeVerifyPKCE
ExpiresAt int64 `db:"expires_at"` // Unix seconds; OAuthCodeCreate sets it 600s after creation
Used int `db:"used"` // 0 = unused, 1 = consumed (set by OAuthCodeUse)
}
// OAuthRefreshToken represents a long-lived refresh token
// Rows live in the oauth_refresh_tokens table, keyed by token_id.
type OAuthRefreshToken struct {
TokenID string `db:"token_id,pk"`
ClientID string `db:"client_id"`
DossierID string `db:"dossier_id"`
ExpiresAt int64 `db:"expires_at"` // Unix seconds; OAuthRefreshTokenCreate sets it 30 days ahead
Revoked int `db:"revoked"` // 0 = active, 1 = revoked
CreatedAt int64 `db:"created_at"` // Unix seconds at creation
}
// generateID creates a random hex ID from the given number of random bytes;
// the returned string is twice that many characters long.
//
// The bytes come from crypto/rand. These IDs serve as client secrets,
// authorization codes and refresh tokens, so a failing random source must
// never silently degrade into an all-zero, predictable ID: the function
// panics instead, since there is no safe way to continue.
func generateID(bytes int) string {
	b := make([]byte, bytes)
	if _, err := rand.Read(b); err != nil {
		panic("generateID: reading from crypto/rand failed: " + err.Error())
	}
	return hex.EncodeToString(b)
}
// hashSecret returns the lowercase hex SHA-256 digest of secret, which is
// the form in which client secrets are stored.
func hashSecret(secret string) string {
	hasher := sha256.New()
	hasher.Write([]byte(secret))
	digest := hasher.Sum(nil)
	return hex.EncodeToString(digest)
}
// OAuthClientCreate registers a new confidential OAuth client.
// It generates a 16-byte client ID and a 32-byte secret, stores only the
// hash of the secret, and returns the stored client together with the
// plain secret (the only time the plain secret is available).
func OAuthClientCreate(name string, redirectURIs []string) (*OAuthClient, string, error) {
	id := generateID(16)
	secret := generateID(32)
	encodedURIs, _ := json.Marshal(redirectURIs)

	created := new(OAuthClient)
	created.ClientID = id
	created.ClientSecret = hashSecret(secret)
	created.Name = name
	created.RedirectURIs = redirectURIs
	created.RedirectJSON = string(encodedURIs)
	created.CreatedAt = time.Now().Unix()

	err := authSave("oauth_clients", created)
	if err != nil {
		return nil, "", err
	}
	return created, secret, nil
}
// OAuthClientCreatePublic stores a public OAuth client under a caller-chosen
// client_id. The client has no secret and an empty redirect URI list.
// Used for first-party clients like the MCP bridge
func OAuthClientCreatePublic(clientID, name string) error {
	encodedURIs, _ := json.Marshal([]string{})

	var public OAuthClient
	public.ClientID = clientID
	public.ClientSecret = "" // Public client - no secret
	public.Name = name
	public.RedirectURIs = []string{}
	public.RedirectJSON = string(encodedURIs)
	public.CreatedAt = time.Now().Unix()

	return authSave("oauth_clients", &public)
}
// OAuthClientGet retrieves a client by ID and decodes its stored redirect
// URI list into RedirectURIs.
//
// It returns the load error when the client cannot be read, and an error
// when the stored redirect_uris value is non-empty but not valid JSON, so
// that a corrupt registration is reported instead of silently behaving like
// a client with no redirect URIs. An empty stored value yields a nil list.
func OAuthClientGet(clientID string) (*OAuthClient, error) {
	var client OAuthClient
	if err := authLoad("oauth_clients", clientID, &client); err != nil {
		return nil, err
	}

	// Parse redirect URIs
	if client.RedirectJSON != "" {
		if err := json.Unmarshal([]byte(client.RedirectJSON), &client.RedirectURIs); err != nil {
			return nil, fmt.Errorf("parsing redirect_uris for client %s: %w", clientID, err)
		}
	}
	return &client, nil
}
// OAuthClientVerifySecret checks if the provided secret matches the hash
// stored for client.
//
// The comparison runs in constant time so response timing does not reveal
// how much of the stored hash matched. A nil client, or a client without a
// stored secret (public client), never verifies.
func OAuthClientVerifySecret(client *OAuthClient, secret string) bool {
	if client == nil || client.ClientSecret == "" {
		return false
	}
	stored := []byte(client.ClientSecret)
	offered := []byte(hashSecret(secret))
	return subtle.ConstantTimeCompare(stored, offered) == 1
}
// OAuthClientValidRedirectURI reports whether uri exactly equals one of the
// redirect URIs registered on client (plain string comparison).
func OAuthClientValidRedirectURI(client *OAuthClient, uri string) bool {
	registered := client.RedirectURIs
	for i := 0; i < len(registered); i++ {
		if registered[i] != uri {
			continue
		}
		return true
	}
	return false
}
// OAuthCodeCreate stores and returns a fresh, unused authorization code for
// the given client/dossier, with a random 32-byte ID and an expiry 600
// seconds (10 minutes) from now.
func OAuthCodeCreate(clientID, dossierID, redirectURI, codeChallenge, codeChallengeMethod string) (*OAuthCode, error) {
	const lifetimeSeconds = 600 // 10 minutes

	issued := new(OAuthCode)
	issued.Code = generateID(32)
	issued.ClientID = clientID
	issued.DossierID = dossierID
	issued.RedirectURI = redirectURI
	issued.CodeChallenge = codeChallenge
	issued.CodeChallengeMethod = codeChallengeMethod
	issued.ExpiresAt = time.Now().Unix() + lifetimeSeconds
	issued.Used = 0

	err := authSave("oauth_codes", issued)
	if err != nil {
		return nil, err
	}
	return issued, nil
}
// OAuthCodeGet loads an authorization code and rejects it when it has
// already been used or when its expiry time has passed. The "used" check
// takes precedence over the expiry check.
func OAuthCodeGet(code string) (*OAuthCode, error) {
	stored := new(OAuthCode)
	if err := authLoad("oauth_codes", code, stored); err != nil {
		return nil, err
	}

	switch {
	case stored.Used != 0:
		return nil, fmt.Errorf("code already used")
	case time.Now().Unix() > stored.ExpiresAt:
		return nil, fmt.Errorf("code expired")
	}
	return stored, nil
}
// OAuthCodeUse flags an authorization code as consumed by loading it,
// setting Used to 1 and saving it back.
// NOTE(review): the load and the save are two separate steps, so this does
// not by itself stop two concurrent exchanges of the same code.
func OAuthCodeUse(code string) error {
	stored := new(OAuthCode)
	err := authLoad("oauth_codes", code, stored)
	if err != nil {
		return err
	}
	stored.Used = 1
	return authSave("oauth_codes", stored)
}
// OAuthCodeVerifyPKCE checks a PKCE code_verifier against the challenge
// stored on the code. A code without a challenge always passes; any
// challenge method other than "S256" always fails. For S256 the verifier
// passes when BASE64URL(SHA256(code_verifier)), unpadded, equals the stored
// challenge.
func OAuthCodeVerifyPKCE(c *OAuthCode, codeVerifier string) bool {
	switch {
	case c.CodeChallenge == "":
		return true // No PKCE required
	case c.CodeChallengeMethod != "S256":
		return false // Only support S256
	}

	hasher := sha256.New()
	hasher.Write([]byte(codeVerifier))
	expected := base64.RawURLEncoding.EncodeToString(hasher.Sum(nil))
	return expected == c.CodeChallenge
}
// OAuthRefreshTokenCreate creates and stores a new refresh token for the
// client/dossier pair, valid for 30 days.
//
// The clock is read once, so ExpiresAt is always exactly CreatedAt plus the
// lifetime (two separate reads could straddle a second boundary).
func OAuthRefreshTokenCreate(clientID, dossierID string) (*OAuthRefreshToken, error) {
	const lifetimeSeconds = 30 * 24 * 60 * 60 // 30 days

	now := time.Now().Unix()
	token := &OAuthRefreshToken{
		TokenID:   generateID(32),
		ClientID:  clientID,
		DossierID: dossierID,
		ExpiresAt: now + lifetimeSeconds,
		Revoked:   0,
		CreatedAt: now,
	}
	if err := authSave("oauth_refresh_tokens", token); err != nil {
		return nil, err
	}
	return token, nil
}
// OAuthRefreshTokenGet loads a refresh token and rejects it when it has been
// revoked or when its expiry time has passed. The revocation check takes
// precedence over the expiry check.
func OAuthRefreshTokenGet(tokenID string) (*OAuthRefreshToken, error) {
	stored := new(OAuthRefreshToken)
	if err := authLoad("oauth_refresh_tokens", tokenID, stored); err != nil {
		return nil, err
	}

	switch {
	case stored.Revoked != 0:
		return nil, fmt.Errorf("token revoked")
	case time.Now().Unix() > stored.ExpiresAt:
		return nil, fmt.Errorf("token expired")
	}
	return stored, nil
}
// OAuthRefreshTokenRevoke marks a refresh token as revoked by loading it,
// setting Revoked to 1 and saving it back. Loading fails for unknown IDs.
func OAuthRefreshTokenRevoke(tokenID string) error {
	stored := new(OAuthRefreshToken)
	err := authLoad("oauth_refresh_tokens", tokenID, stored)
	if err != nil {
		return err
	}
	stored.Revoked = 1
	return authSave("oauth_refresh_tokens", stored)
}
// OAuthRefreshTokenRotate exchanges a valid refresh token for a new one:
// the old token is validated, revoked, and a fresh token is issued for the
// same client and dossier. The three steps run as separate operations.
func OAuthRefreshTokenRotate(oldTokenID string) (*OAuthRefreshToken, error) {
	previous, err := OAuthRefreshTokenGet(oldTokenID)
	if err == nil {
		// Only revoke a token that passed validation.
		err = OAuthRefreshTokenRevoke(oldTokenID)
	}
	if err != nil {
		return nil, err
	}
	return OAuthRefreshTokenCreate(previous.ClientID, previous.DossierID)
}
// OAuthRefreshTokenRevokeAll revokes every not-yet-revoked refresh token of
// a dossier, one save per token, stopping at the first failure.
func OAuthRefreshTokenRevokeAll(dossierID string) error {
	const activeForDossier = "SELECT * FROM oauth_refresh_tokens WHERE dossier_id = ? AND revoked = 0"

	var active []*OAuthRefreshToken
	err := authQuery(activeForDossier, []any{dossierID}, &active)
	if err != nil {
		return err
	}

	for i := range active {
		active[i].Revoked = 1
		if err := authSave("oauth_refresh_tokens", active[i]); err != nil {
			return err
		}
	}
	return nil
}
// OAuthRefreshTokenGetForClient returns the most recently created refresh
// token for the client/dossier pair that is neither revoked nor expired.
// It returns an error when no such token exists or the lookup fails.
func OAuthRefreshTokenGetForClient(clientID, dossierID string) (*OAuthRefreshToken, error) {
	const newestValid = "SELECT * FROM oauth_refresh_tokens WHERE client_id = ? AND dossier_id = ? AND revoked = 0 AND expires_at > ? ORDER BY created_at DESC LIMIT 1"

	var found []*OAuthRefreshToken
	params := []any{clientID, dossierID, time.Now().Unix()}
	if err := authQuery(newestValid, params, &found); err != nil {
		return nil, err
	}

	if len(found) > 0 {
		return found[0], nil
	}
	return nil, fmt.Errorf("no token found")
}
// OAuthRefreshTokenGetOrCreate returns the existing valid refresh token for
// the client/dossier pair, or creates a new one when the lookup returns any
// error (including "no token found").
func OAuthRefreshTokenGetOrCreate(clientID, dossierID string) (*OAuthRefreshToken, error) {
	if existing, err := OAuthRefreshTokenGetForClient(clientID, dossierID); err == nil {
		return existing, nil
	}
	return OAuthRefreshTokenCreate(clientID, dossierID)
}
// OAuthRefreshTokenRegenerate revokes all existing tokens for a
// client/dossier pair and creates a new one.
//
// If any revocation fails, the error is returned and no new token is
// issued: handing out a replacement while an old token is still valid would
// defeat the purpose of regenerating.
func OAuthRefreshTokenRegenerate(clientID, dossierID string) (*OAuthRefreshToken, error) {
	var tokens []*OAuthRefreshToken
	if err := authQuery("SELECT * FROM oauth_refresh_tokens WHERE client_id = ? AND dossier_id = ? AND revoked = 0", []any{clientID, dossierID}, &tokens); err != nil {
		return nil, err
	}
	for _, t := range tokens {
		t.Revoked = 1
		if err := authSave("oauth_refresh_tokens", t); err != nil {
			return nil, fmt.Errorf("revoking refresh token: %w", err)
		}
	}
	return OAuthRefreshTokenCreate(clientID, dossierID)
}
// OAuthCleanup removes stale authorization codes and refresh tokens (call
// periodically).
//
// Cleanup is best-effort: a failure while listing one kind of record does
// not stop the other kind from being cleaned. Unlike before, such failures
// are no longer swallowed — the first one is returned so the caller can log
// or retry.
func OAuthCleanup() error {
	now := time.Now().Unix()
	var firstErr error

	// Codes: selected when their expiry is more than an hour in the past, or
	// when they are marked used.
	// NOTE(review): the original comment said used codes are kept for one
	// hour for debugging, but `OR used = 1` matches them immediately —
	// confirm which behaviour is intended.
	var codes []*OAuthCode
	if err := authQuery("SELECT * FROM oauth_codes WHERE expires_at < ? OR used = 1", []any{now - 3600}, &codes); err != nil {
		firstErr = fmt.Errorf("listing stale oauth codes: %w", err)
	} else {
		for _, c := range codes {
			// NOTE(review): authDelete's result, if any, is not visible from
			// this block; if it returns an error it should be collected too.
			authDelete("oauth_codes", "code", c.Code)
		}
	}

	// Tokens: selected only when revoked AND expired more than a day ago.
	// NOTE(review): expired tokens that were never revoked are not matched
	// by this query and therefore are never removed here — confirm.
	var tokens []*OAuthRefreshToken
	if err := authQuery("SELECT * FROM oauth_refresh_tokens WHERE expires_at < ? AND revoked = 1", []any{now - 86400}, &tokens); err != nil {
		if firstErr == nil {
			firstErr = fmt.Errorf("listing stale oauth refresh tokens: %w", err)
		}
	} else {
		for _, t := range tokens {
			authDelete("oauth_refresh_tokens", "token_id", t.TokenID)
		}
	}

	return firstErr
}