1132 lines
29 KiB
Go
1132 lines
29 KiB
Go
package lib
|
|
|
|
// ============================================================================
|
|
// ⛔ CRITICAL: DO NOT MODIFY THIS FILE WITHOUT JOHAN'S EXPRESS CONSENT
|
|
// ============================================================================
|
|
// This is the ONLY file allowed to access the database directly.
|
|
// Internal DB functions (unexported): dbSave, dbLoad, dbQuery, dbDelete, dbCount
|
|
// External code must use RBAC-checked functions (EntryWrite, DossierGet, etc.)
|
|
//
|
|
// Run `make check-db` to verify no direct DB access exists elsewhere.
|
|
// ============================================================================
|
|
|
|
import (
	"crypto/rand"
	"crypto/sha256"
	"crypto/subtle"
	"database/sql"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"reflect"
	"strings"
	"time"
)
|
|
|
|
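// Store provides a generic data layer with automatic encryption.
// Non-empty string and []byte fields are encrypted automatically, except string
// columns whose name ends in "_id" and the plain-text columns language,
// timezone, weight_unit and height_unit, which are stored as-is.
// Int/int64/bool fields pass through unchanged.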
|
|
//
|
|
// Struct tags:
|
|
// db:"column_name" - maps field to column (default: lowercase field name)
|
|
// db:"column_name,pk" - marks as primary key
|
|
// db:"-" - skip this field
|
|
//
|
|
// Example:
|
|
// type Entry struct {
|
|
// EntryID int64 `db:"entry_id,pk"`
|
|
// Type string `db:"type"` // auto-encrypted
|
|
// Status int `db:"status"` // not encrypted
|
|
// }
|
|
|
|
// fieldInfo holds metadata about a single struct field mapped to a DB column.
type fieldInfo struct {
	Name   string       // struct field name
	Column string       // db column name
	Type   reflect.Type // Go type
	IsPK   bool         // is primary key
	Skip   bool         // skip this field (unused: `db:"-"` fields are simply omitted from Fields)
	Index  int          // field index in struct
}

// tableInfo holds metadata about a struct/table mapping.
type tableInfo struct {
	Name   string      // table name ("" for ad-hoc queries)
	Fields []fieldInfo // mapped fields, in struct declaration order
	PK     *fieldInfo  // points into Fields; nil when no field is tagged ",pk"
}

// getTableInfo extracts table metadata from a struct (or pointer to struct)
// using reflection.
//
// Column names come from the `db` tag. An untagged field, or a tag with an
// empty name such as `db:",pk"`, maps to its lowercased field name. Fields
// tagged `db:"-"` are skipped. Whitespace around the name and options is
// ignored, so `db:"col, pk"` works.
//
// A primary key is required when table is non-empty; for ad-hoc queries
// (table == "") it is optional. If several fields are tagged ",pk", the last
// one wins.
func getTableInfo(table string, v any) (*tableInfo, error) {
	t := reflect.TypeOf(v)
	if t == nil {
		return nil, fmt.Errorf("expected struct, got nil")
	}
	if t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	if t.Kind() != reflect.Struct {
		return nil, fmt.Errorf("expected struct, got %s", t.Kind())
	}

	info := &tableInfo{Name: table}
	pkIdx := -1

	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)

		tag := field.Tag.Get("db")
		if tag == "-" {
			continue
		}

		fi := fieldInfo{
			Name:   field.Name,
			Column: strings.ToLower(field.Name), // default when the tag gives no name
			Type:   field.Type,
			Index:  i,
		}

		if tag != "" {
			parts := strings.Split(tag, ",")
			if name := strings.TrimSpace(parts[0]); name != "" {
				fi.Column = name
			}
			for _, opt := range parts[1:] {
				if strings.TrimSpace(opt) == "pk" {
					fi.IsPK = true
				}
			}
		}

		info.Fields = append(info.Fields, fi)
		if fi.IsPK {
			pkIdx = len(info.Fields) - 1
		}
	}

	// Take the PK pointer only after Fields has stopped growing. A pointer taken
	// mid-loop is left aimed at a stale backing array once append reallocates.
	if pkIdx >= 0 {
		info.PK = &info.Fields[pkIdx]
	}

	// PK required for table operations, optional for ad-hoc queries (table == "").
	if table != "" && info.PK == nil {
		return nil, fmt.Errorf("no primary key defined for table %s (use `db:\"col,pk\"`)", table)
	}

	return info, nil
}
|
|
|
|
// goTypeToSQLite returns the SQLite column type suggested for a Go type:
//   - []byte (any slice of uint8) -> BLOB
//   - bool and all signed/unsigned integer kinds -> INTEGER
//   - float32/float64 -> REAL
//   - everything else (strings, other slices, uintptr, structs, ...) -> TEXT
func goTypeToSQLite(t reflect.Type) string {
	kind := t.Kind()
	if kind == reflect.Slice && t.Elem().Kind() == reflect.Uint8 {
		return "BLOB"
	}

	switch kind {
	case reflect.Bool,
		reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return "INTEGER"
	case reflect.Float32, reflect.Float64:
		return "REAL"
	}

	// Strings, non-byte slices and any other kind fall back to TEXT.
	return "TEXT"
}
|
|
|
|
// Verify checks that the database schema matches the struct definition.
|
|
// Returns nil if OK, or an error with suggested ALTER statements.
|
|
func Verify(table string, v any) error {
|
|
info, err := getTableInfo(table, v)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Get current columns from DB
|
|
rows, err := db.Query(fmt.Sprintf("PRAGMA table_info(%s)", table))
|
|
if err != nil {
|
|
return fmt.Errorf("failed to query table info: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
dbColumns := make(map[string]string) // column -> type
|
|
for rows.Next() {
|
|
var cid int
|
|
var name, colType string
|
|
var notNull, pk int
|
|
var dflt sql.NullString
|
|
if err := rows.Scan(&cid, &name, &colType, ¬Null, &dflt, &pk); err != nil {
|
|
return err
|
|
}
|
|
dbColumns[name] = colType
|
|
}
|
|
|
|
// Check if table exists
|
|
if len(dbColumns) == 0 {
|
|
// Generate CREATE TABLE
|
|
var cols []string
|
|
for _, f := range info.Fields {
|
|
colDef := fmt.Sprintf("%s %s", f.Column, goTypeToSQLite(f.Type))
|
|
if f.IsPK {
|
|
colDef += " PRIMARY KEY"
|
|
}
|
|
cols = append(cols, colDef)
|
|
}
|
|
return fmt.Errorf(`schema mismatch: table '%s' does not exist
|
|
|
|
Suggested fix:
|
|
CREATE TABLE %s (
|
|
%s
|
|
);
|
|
`, table, table, strings.Join(cols, ",\n "))
|
|
}
|
|
|
|
// Check for missing columns
|
|
var missing []fieldInfo
|
|
for _, f := range info.Fields {
|
|
if _, exists := dbColumns[f.Column]; !exists {
|
|
missing = append(missing, f)
|
|
}
|
|
}
|
|
|
|
if len(missing) > 0 {
|
|
var alters []string
|
|
var details []string
|
|
for _, f := range missing {
|
|
sqlType := goTypeToSQLite(f.Type)
|
|
details = append(details, fmt.Sprintf(" - %s (%s)", f.Column, sqlType))
|
|
alters = append(alters, fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s;", table, f.Column, sqlType))
|
|
}
|
|
return fmt.Errorf(`schema mismatch for table '%s'
|
|
|
|
Missing columns:
|
|
%s
|
|
|
|
Suggested fix:
|
|
%s
|
|
|
|
Run these manually, then restart.
|
|
`, table, strings.Join(details, "\n"), strings.Join(alters, "\n "))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// VerifyAll checks multiple table/struct pairs
|
|
func VerifyAll(pairs ...any) error {
|
|
if len(pairs)%2 != 0 {
|
|
return fmt.Errorf("VerifyAll requires pairs of (table string, struct)")
|
|
}
|
|
for i := 0; i < len(pairs); i += 2 {
|
|
table, ok := pairs[i].(string)
|
|
if !ok {
|
|
return fmt.Errorf("expected string for table name at position %d", i)
|
|
}
|
|
if err := Verify(table, pairs[i+1]); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// dbSave upserts struct(s) to the database.
|
|
// Accepts a single struct or a slice of structs.
|
|
// String and []byte fields are encrypted automatically.
|
|
// Slices are wrapped in a transaction for atomicity.
|
|
func dbSave(table string, v any) error {
|
|
start := time.Now()
|
|
defer func() { logSlowQuery("INSERT OR REPLACE INTO "+table, time.Since(start)) }()
|
|
|
|
val := reflect.ValueOf(v)
|
|
if val.Kind() == reflect.Ptr {
|
|
val = val.Elem()
|
|
}
|
|
|
|
// Slice: bulk upsert with transaction
|
|
if val.Kind() == reflect.Slice {
|
|
if val.Len() == 0 {
|
|
return nil
|
|
}
|
|
|
|
// Get table info from first element
|
|
elemType := val.Type().Elem()
|
|
if elemType.Kind() == reflect.Ptr {
|
|
elemType = elemType.Elem()
|
|
}
|
|
sample := reflect.New(elemType).Interface()
|
|
info, err := getTableInfo(table, sample)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Build query once
|
|
var columns []string
|
|
var placeholders []string
|
|
for _, f := range info.Fields {
|
|
columns = append(columns, f.Column)
|
|
placeholders = append(placeholders, "?")
|
|
}
|
|
query := fmt.Sprintf(
|
|
"INSERT OR REPLACE INTO %s (%s) VALUES (%s)",
|
|
table,
|
|
strings.Join(columns, ", "),
|
|
strings.Join(placeholders, ", "),
|
|
)
|
|
|
|
// Transaction with prepared statement
|
|
tx, err := db.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
stmt, err := tx.Prepare(query)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer stmt.Close()
|
|
|
|
for i := 0; i < val.Len(); i++ {
|
|
elem := val.Index(i)
|
|
if elem.Kind() == reflect.Ptr {
|
|
elem = elem.Elem()
|
|
}
|
|
|
|
values := make([]any, len(info.Fields))
|
|
for j, f := range info.Fields {
|
|
values[j] = encryptField(elem.Field(f.Index), f.Column)
|
|
}
|
|
|
|
if _, err := stmt.Exec(values...); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return tx.Commit()
|
|
}
|
|
|
|
// Single struct
|
|
info, err := getTableInfo(table, v)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var columns []string
|
|
var placeholders []string
|
|
var values []any
|
|
|
|
for _, f := range info.Fields {
|
|
columns = append(columns, f.Column)
|
|
placeholders = append(placeholders, "?")
|
|
values = append(values, encryptField(val.Field(f.Index), f.Column))
|
|
}
|
|
|
|
query := fmt.Sprintf(
|
|
"INSERT OR REPLACE INTO %s (%s) VALUES (%s)",
|
|
table,
|
|
strings.Join(columns, ", "),
|
|
strings.Join(placeholders, ", "),
|
|
)
|
|
|
|
_, err = db.Exec(query, values...)
|
|
return err
|
|
}
|
|
|
|
// dbLoad retrieves a record by primary key and populates the struct.
|
|
// String and []byte fields are decrypted automatically.
|
|
func dbLoad(table string, id string, v any) error {
|
|
start := time.Now()
|
|
defer func() { logSlowQuery("SELECT FROM "+table+" WHERE pk=?", time.Since(start), id) }()
|
|
|
|
info, err := getTableInfo(table, v)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
val := reflect.ValueOf(v)
|
|
if val.Kind() != reflect.Ptr {
|
|
return fmt.Errorf("Load requires a pointer to struct")
|
|
}
|
|
val = val.Elem()
|
|
|
|
// Build SELECT
|
|
var columns []string
|
|
for _, f := range info.Fields {
|
|
columns = append(columns, f.Column)
|
|
}
|
|
|
|
query := fmt.Sprintf(
|
|
"SELECT %s FROM %s WHERE %s = ?",
|
|
strings.Join(columns, ", "),
|
|
table,
|
|
info.PK.Column,
|
|
)
|
|
|
|
row := db.QueryRow(query, id)
|
|
|
|
// Create scan destinations
|
|
scanDest := make([]any, len(info.Fields))
|
|
for i, f := range info.Fields {
|
|
scanDest[i] = createScanDest(f.Type)
|
|
}
|
|
|
|
if err := row.Scan(scanDest...); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Decrypt and set fields
|
|
for i, f := range info.Fields {
|
|
decryptAndSet(val.Field(f.Index), scanDest[i], f.Type, f.Column)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// dbQuery runs a SQL query and populates the slice.
|
|
// Column names in result must match struct db tags.
|
|
// String and []byte fields are decrypted automatically.
|
|
func dbQuery(query string, args []any, slicePtr any) error {
|
|
start := time.Now()
|
|
defer func() { logSlowQuery(query, time.Since(start), args...) }()
|
|
|
|
sliceVal := reflect.ValueOf(slicePtr)
|
|
if sliceVal.Kind() != reflect.Ptr || sliceVal.Elem().Kind() != reflect.Slice {
|
|
return fmt.Errorf("Query requires a pointer to slice")
|
|
}
|
|
|
|
sliceType := sliceVal.Elem().Type()
|
|
elemType := sliceType.Elem()
|
|
if elemType.Kind() == reflect.Ptr {
|
|
elemType = elemType.Elem()
|
|
}
|
|
|
|
// Get struct field info
|
|
sample := reflect.New(elemType).Interface()
|
|
info, err := getTableInfo("", sample)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Build column->field mapping
|
|
colToField := make(map[string]*fieldInfo)
|
|
for i := range info.Fields {
|
|
colToField[info.Fields[i].Column] = &info.Fields[i]
|
|
}
|
|
|
|
rows, err := db.Query(query, args...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
|
|
// Get column names from result
|
|
cols, err := rows.Columns()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
result := reflect.MakeSlice(sliceType, 0, 0)
|
|
|
|
for rows.Next() {
|
|
elem := reflect.New(elemType).Elem()
|
|
|
|
// Create scan destinations for each column
|
|
scanDest := make([]any, len(cols))
|
|
fieldMap := make([]*fieldInfo, len(cols))
|
|
|
|
for i, col := range cols {
|
|
if fi, ok := colToField[col]; ok {
|
|
scanDest[i] = createScanDest(fi.Type)
|
|
fieldMap[i] = fi
|
|
} else {
|
|
// Unknown column - scan into throwaway
|
|
scanDest[i] = new(any)
|
|
}
|
|
}
|
|
|
|
if err := rows.Scan(scanDest...); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Decrypt and set matched fields
|
|
for i, fi := range fieldMap {
|
|
if fi != nil {
|
|
decryptAndSet(elem.Field(fi.Index), scanDest[i], fi.Type, fi.Column)
|
|
}
|
|
}
|
|
|
|
if sliceType.Elem().Kind() == reflect.Ptr {
|
|
result = reflect.Append(result, elem.Addr())
|
|
} else {
|
|
result = reflect.Append(result, elem)
|
|
}
|
|
}
|
|
|
|
sliceVal.Elem().Set(result)
|
|
return nil
|
|
}
|
|
|
|
// dbCount runs a SELECT COUNT(*) query and returns the result.
|
|
func dbCount(query string, args ...any) (int, error) {
|
|
var count int
|
|
err := db.QueryRow(query, args...).Scan(&count)
|
|
return count, err
|
|
}
|
|
|
|
// dbDelete removes a record by primary key.
|
|
// pkCol is the primary key column name, id is 16-char hex string.
|
|
func dbDelete(table, pkCol, id string) error {
|
|
query := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", table, pkCol)
|
|
_, err := db.Exec(query, id)
|
|
return err
|
|
}
|
|
|
|
// dbDeleteTree removes a record and all its descendants.
|
|
// Traverses the parent-child hierarchy recursively, deletes children first.
|
|
// Works with any SQL database (no CTEs or CASCADE needed).
|
|
func dbDeleteTree(table, pkCol, parentCol, id string) error {
|
|
// Collect all IDs (parent + descendants)
|
|
var ids []string
|
|
var collect func(string) error
|
|
collect = func(pid string) error {
|
|
ids = append(ids, pid)
|
|
rows, err := db.Query(
|
|
fmt.Sprintf("SELECT %s FROM %s WHERE %s = ?", pkCol, table, parentCol),
|
|
pid,
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
var childID string
|
|
if err := rows.Scan(&childID); err != nil {
|
|
return err
|
|
}
|
|
if err := collect(childID); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
if err := collect(id); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Delete in reverse order (children first)
|
|
tx, err := db.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
for i := len(ids) - 1; i >= 0; i-- {
|
|
if _, err := tx.Exec(
|
|
fmt.Sprintf("DELETE FROM %s WHERE %s = ?", table, pkCol),
|
|
ids[i],
|
|
); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return tx.Commit()
|
|
}
|
|
|
|
// encryptField encrypts string and []byte fields, passes others through
|
|
// ID columns (ending in _id) are NOT encrypted
|
|
func encryptField(v reflect.Value, column string) any {
|
|
switch v.Kind() {
|
|
case reflect.String:
|
|
s := v.String()
|
|
if s == "" {
|
|
return ""
|
|
}
|
|
// Don't encrypt ID columns or known plain-text columns
|
|
plainCols := map[string]bool{"language": true, "timezone": true, "weight_unit": true, "height_unit": true}
|
|
if strings.HasSuffix(column, "_id") || plainCols[column] {
|
|
return s
|
|
}
|
|
return CryptoEncrypt(s)
|
|
|
|
case reflect.Slice:
|
|
if v.Type().Elem().Kind() == reflect.Uint8 {
|
|
// []byte
|
|
b := v.Bytes()
|
|
if len(b) == 0 {
|
|
return []byte{}
|
|
}
|
|
return CryptoEncryptBytes(b)
|
|
}
|
|
// Other slices: JSON encode then encrypt
|
|
return v.Interface()
|
|
|
|
default:
|
|
return v.Interface()
|
|
}
|
|
}
|
|
|
|
// createScanDest returns a fresh pointer to use as a rows.Scan destination for
// a field of Go type t. decryptAndSet later copies it into the struct field.
//
//   - signed/unsigned integer kinds and bool -> *sql.NullInt64
//   - float32/float64                        -> *sql.NullFloat64
//   - []byte                                 -> *[]byte
//   - strings and every other kind           -> *sql.NullString
//
// Previously only int/int64/bool got an integer destination and floats fell
// through to *sql.NullString. That meant float and narrower integer fields
// written by dbSave were never loaded back.
func createScanDest(t reflect.Type) any {
	switch t.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
		reflect.Bool:
		return new(sql.NullInt64)
	case reflect.Float32, reflect.Float64:
		return new(sql.NullFloat64)
	case reflect.Slice:
		if t.Elem().Kind() == reflect.Uint8 {
			return new([]byte)
		}
	}
	return new(sql.NullString)
}
|
|
|
|
// decryptAndSet decrypts the scanned value and sets the struct field
|
|
// ID columns (ending in _id) are NOT decrypted
|
|
func decryptAndSet(field reflect.Value, scanned any, t reflect.Type, column string) {
|
|
switch t.Kind() {
|
|
case reflect.Int, reflect.Int64:
|
|
if ns, ok := scanned.(*sql.NullInt64); ok && ns.Valid {
|
|
field.SetInt(ns.Int64)
|
|
}
|
|
|
|
case reflect.String:
|
|
if ns, ok := scanned.(*sql.NullString); ok && ns.Valid {
|
|
// Don't decrypt ID columns or known plain-text columns
|
|
plainCols := map[string]bool{"language": true, "timezone": true, "weight_unit": true, "height_unit": true}
|
|
if strings.HasSuffix(column, "_id") || plainCols[column] {
|
|
field.SetString(ns.String)
|
|
} else {
|
|
field.SetString(CryptoDecrypt(ns.String))
|
|
}
|
|
}
|
|
|
|
case reflect.Slice:
|
|
if t.Elem().Kind() == reflect.Uint8 {
|
|
// []byte - decrypt
|
|
if b, ok := scanned.(*[]byte); ok && b != nil && len(*b) > 0 {
|
|
decrypted, err := CryptoDecryptBytes(*b)
|
|
if err == nil {
|
|
field.SetBytes(decrypted)
|
|
}
|
|
}
|
|
}
|
|
|
|
case reflect.Bool:
|
|
if ns, ok := scanned.(*sql.NullInt64); ok && ns.Valid {
|
|
field.SetBool(ns.Int64 != 0)
|
|
}
|
|
}
|
|
}
|
|
|
|
// =============================================================================
|
|
// OAuth 2.0 Store Functions
|
|
// =============================================================================
|
|
|
|
// OAuthClient represents a registered OAuth client (Claude, Flutter, etc.).
// Rows live in the oauth_clients table (written by OAuthClientCreate and
// OAuthClientCreatePublic, read by OAuthClientGet). Field order and db tags
// define the column mapping used by the reflection-based store.
type OAuthClient struct {
	ClientID     string   `db:"client_id,pk"`
	ClientSecret string   `db:"client_secret"` // hex SHA-256 of the secret (hashSecret); "" for public clients
	Name         string   `db:"name"`
	RedirectURIs []string `db:"-"`             // not a column; decoded from RedirectJSON by OAuthClientGet
	RedirectJSON string   `db:"redirect_uris"` // JSON array of registered redirect URIs
	CreatedAt    int64    `db:"created_at"`    // Unix seconds
}
|
|
|
|
// OAuthCode represents a temporary, single-use authorization code stored in
// the oauth_codes table. OAuthCodeCreate issues it with a 10-minute expiry;
// OAuthCodeGet rejects it once used or expired.
type OAuthCode struct {
	Code                string `db:"code,pk"` // random hex from generateID(32)
	ClientID            string `db:"client_id"`
	DossierID           string `db:"dossier_id"`
	RedirectURI         string `db:"redirect_uri"`
	CodeChallenge       string `db:"code_challenge"`        // PKCE challenge; "" when PKCE was not used
	CodeChallengeMethod string `db:"code_challenge_method"` // only "S256" is accepted by OAuthCodeVerifyPKCE
	ExpiresAt           int64  `db:"expires_at"`            // Unix seconds
	Used                int    `db:"used"`                  // 0 = unused, 1 = redeemed (OAuthCodeUse)
}
|
|
|
|
// OAuthRefreshToken represents a long-lived refresh token stored in the
// oauth_refresh_tokens table. OAuthRefreshTokenCreate issues it with a 30-day
// expiry; OAuthRefreshTokenGet rejects it once revoked or expired.
type OAuthRefreshToken struct {
	TokenID   string `db:"token_id,pk"` // random hex from generateID(32)
	ClientID  string `db:"client_id"`
	DossierID string `db:"dossier_id"`
	ExpiresAt int64  `db:"expires_at"` // Unix seconds
	Revoked   int    `db:"revoked"`    // 0 = active, 1 = revoked
	CreatedAt int64  `db:"created_at"` // Unix seconds
}
|
|
|
|
// generateID returns `bytes` cryptographically random bytes, hex-encoded, so
// the result is 2*bytes characters long.
//
// It panics if the system random source fails. These IDs are used as OAuth
// client IDs, client secrets, authorization codes and refresh tokens, and
// silently returning a predictable (all-zero) value would be a security hole.
// Since Go 1.24, crypto/rand.Read never returns an error and aborts the
// program itself on failure, so the check only matters on older toolchains.
func generateID(bytes int) string {
	b := make([]byte, bytes)
	if _, err := rand.Read(b); err != nil {
		panic(fmt.Sprintf("generateID: crypto/rand failed: %v", err))
	}
	return hex.EncodeToString(b)
}
|
|
|
|
// hashSecret returns the lowercase hex SHA-256 digest of secret. Client
// secrets are stored in this form, never in plain text.
func hashSecret(secret string) string {
	hasher := sha256.New()
	hasher.Write([]byte(secret)) // hash.Hash.Write never returns an error
	return hex.EncodeToString(hasher.Sum(nil))
}
|
|
|
|
// OAuthClientCreate creates a new OAuth client and returns the plain secret
|
|
func OAuthClientCreate(name string, redirectURIs []string) (*OAuthClient, string, error) {
|
|
clientID := generateID(16)
|
|
plainSecret := generateID(32)
|
|
|
|
urisJSON, _ := json.Marshal(redirectURIs)
|
|
|
|
client := &OAuthClient{
|
|
ClientID: clientID,
|
|
ClientSecret: hashSecret(plainSecret),
|
|
Name: name,
|
|
RedirectURIs: redirectURIs,
|
|
RedirectJSON: string(urisJSON),
|
|
CreatedAt: time.Now().Unix(),
|
|
}
|
|
|
|
if err := authSave("oauth_clients", client); err != nil {
|
|
return nil, "", err
|
|
}
|
|
|
|
return client, plainSecret, nil
|
|
}
|
|
|
|
// OAuthClientCreatePublic creates a public OAuth client with a fixed client_id (no secret)
|
|
// Used for first-party clients like the MCP bridge
|
|
func OAuthClientCreatePublic(clientID, name string) error {
|
|
urisJSON, _ := json.Marshal([]string{})
|
|
|
|
client := &OAuthClient{
|
|
ClientID: clientID,
|
|
ClientSecret: "", // Public client - no secret
|
|
Name: name,
|
|
RedirectURIs: []string{},
|
|
RedirectJSON: string(urisJSON),
|
|
CreatedAt: time.Now().Unix(),
|
|
}
|
|
|
|
return authSave("oauth_clients", client)
|
|
}
|
|
|
|
// OAuthClientGet retrieves a client by ID
|
|
func OAuthClientGet(clientID string) (*OAuthClient, error) {
|
|
var client OAuthClient
|
|
if err := authLoad("oauth_clients", clientID, &client); err != nil {
|
|
return nil, err
|
|
}
|
|
// Parse redirect URIs
|
|
json.Unmarshal([]byte(client.RedirectJSON), &client.RedirectURIs)
|
|
return &client, nil
|
|
}
|
|
|
|
// OAuthClientVerifySecret checks if the provided secret matches
|
|
func OAuthClientVerifySecret(client *OAuthClient, secret string) bool {
|
|
return client.ClientSecret == hashSecret(secret)
|
|
}
|
|
|
|
// OAuthClientValidRedirectURI checks if URI is registered for client
|
|
func OAuthClientValidRedirectURI(client *OAuthClient, uri string) bool {
|
|
for _, u := range client.RedirectURIs {
|
|
if u == uri {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// OAuthCodeCreate creates a new authorization code (valid for 10 minutes)
|
|
func OAuthCodeCreate(clientID, dossierID, redirectURI, codeChallenge, codeChallengeMethod string) (*OAuthCode, error) {
|
|
code := &OAuthCode{
|
|
Code: generateID(32),
|
|
ClientID: clientID,
|
|
DossierID: dossierID,
|
|
RedirectURI: redirectURI,
|
|
CodeChallenge: codeChallenge,
|
|
CodeChallengeMethod: codeChallengeMethod,
|
|
ExpiresAt: time.Now().Unix() + 600, // 10 minutes
|
|
Used: 0,
|
|
}
|
|
|
|
if err := authSave("oauth_codes", code); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return code, nil
|
|
}
|
|
|
|
// OAuthCodeGet retrieves and validates a code (checks expiry, not used)
|
|
func OAuthCodeGet(code string) (*OAuthCode, error) {
|
|
var c OAuthCode
|
|
if err := authLoad("oauth_codes", code, &c); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if c.Used != 0 {
|
|
return nil, fmt.Errorf("code already used")
|
|
}
|
|
|
|
if time.Now().Unix() > c.ExpiresAt {
|
|
return nil, fmt.Errorf("code expired")
|
|
}
|
|
|
|
return &c, nil
|
|
}
|
|
|
|
// OAuthCodeUse marks a code as used (single-use)
|
|
func OAuthCodeUse(code string) error {
|
|
var c OAuthCode
|
|
if err := authLoad("oauth_codes", code, &c); err != nil {
|
|
return err
|
|
}
|
|
c.Used = 1
|
|
return authSave("oauth_codes", &c)
|
|
}
|
|
|
|
// OAuthCodeVerifyPKCE verifies the PKCE code_verifier against stored challenge
|
|
func OAuthCodeVerifyPKCE(c *OAuthCode, codeVerifier string) bool {
|
|
if c.CodeChallenge == "" {
|
|
return true // No PKCE required
|
|
}
|
|
|
|
if c.CodeChallengeMethod != "S256" {
|
|
return false // Only support S256
|
|
}
|
|
|
|
// S256: BASE64URL(SHA256(code_verifier)) == code_challenge
|
|
h := sha256.Sum256([]byte(codeVerifier))
|
|
computed := base64.RawURLEncoding.EncodeToString(h[:])
|
|
|
|
return computed == c.CodeChallenge
|
|
}
|
|
|
|
// OAuthRefreshTokenCreate creates a new refresh token (valid for 30 days)
|
|
func OAuthRefreshTokenCreate(clientID, dossierID string) (*OAuthRefreshToken, error) {
|
|
token := &OAuthRefreshToken{
|
|
TokenID: generateID(32),
|
|
ClientID: clientID,
|
|
DossierID: dossierID,
|
|
ExpiresAt: time.Now().Unix() + 30*24*60*60, // 30 days
|
|
Revoked: 0,
|
|
CreatedAt: time.Now().Unix(),
|
|
}
|
|
|
|
if err := authSave("oauth_refresh_tokens", token); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return token, nil
|
|
}
|
|
|
|
// OAuthRefreshTokenGet retrieves and validates a refresh token
|
|
func OAuthRefreshTokenGet(tokenID string) (*OAuthRefreshToken, error) {
|
|
var t OAuthRefreshToken
|
|
if err := authLoad("oauth_refresh_tokens", tokenID, &t); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if t.Revoked != 0 {
|
|
return nil, fmt.Errorf("token revoked")
|
|
}
|
|
|
|
if time.Now().Unix() > t.ExpiresAt {
|
|
return nil, fmt.Errorf("token expired")
|
|
}
|
|
|
|
return &t, nil
|
|
}
|
|
|
|
// OAuthRefreshTokenRevoke revokes a refresh token
|
|
func OAuthRefreshTokenRevoke(tokenID string) error {
|
|
var t OAuthRefreshToken
|
|
if err := authLoad("oauth_refresh_tokens", tokenID, &t); err != nil {
|
|
return err
|
|
}
|
|
t.Revoked = 1
|
|
return authSave("oauth_refresh_tokens", &t)
|
|
}
|
|
|
|
// OAuthRefreshTokenRotate revokes old token and creates new one
|
|
func OAuthRefreshTokenRotate(oldTokenID string) (*OAuthRefreshToken, error) {
|
|
old, err := OAuthRefreshTokenGet(oldTokenID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Revoke old
|
|
if err := OAuthRefreshTokenRevoke(oldTokenID); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Create new
|
|
return OAuthRefreshTokenCreate(old.ClientID, old.DossierID)
|
|
}
|
|
|
|
// OAuthRefreshTokenRevokeAll revokes all refresh tokens for a dossier
|
|
func OAuthRefreshTokenRevokeAll(dossierID string) error {
|
|
var tokens []*OAuthRefreshToken
|
|
if err := authQuery("SELECT * FROM oauth_refresh_tokens WHERE dossier_id = ? AND revoked = 0", []any{dossierID}, &tokens); err != nil {
|
|
return err
|
|
}
|
|
for _, t := range tokens {
|
|
t.Revoked = 1
|
|
if err := authSave("oauth_refresh_tokens", t); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// OAuthRefreshTokenGetForClient gets an existing valid refresh token for a client/dossier pair
|
|
func OAuthRefreshTokenGetForClient(clientID, dossierID string) (*OAuthRefreshToken, error) {
|
|
var tokens []*OAuthRefreshToken
|
|
now := time.Now().Unix()
|
|
if err := authQuery("SELECT * FROM oauth_refresh_tokens WHERE client_id = ? AND dossier_id = ? AND revoked = 0 AND expires_at > ? ORDER BY created_at DESC LIMIT 1", []any{clientID, dossierID, now}, &tokens); err != nil {
|
|
return nil, err
|
|
}
|
|
if len(tokens) == 0 {
|
|
return nil, fmt.Errorf("no token found")
|
|
}
|
|
return tokens[0], nil
|
|
}
|
|
|
|
// OAuthRefreshTokenGetOrCreate gets or creates a refresh token for a client/dossier pair
|
|
func OAuthRefreshTokenGetOrCreate(clientID, dossierID string) (*OAuthRefreshToken, error) {
|
|
token, err := OAuthRefreshTokenGetForClient(clientID, dossierID)
|
|
if err == nil {
|
|
return token, nil
|
|
}
|
|
return OAuthRefreshTokenCreate(clientID, dossierID)
|
|
}
|
|
|
|
// OAuthRefreshTokenRegenerate revokes all existing tokens for a client/dossier and creates new one
|
|
func OAuthRefreshTokenRegenerate(clientID, dossierID string) (*OAuthRefreshToken, error) {
|
|
var tokens []*OAuthRefreshToken
|
|
if err := authQuery("SELECT * FROM oauth_refresh_tokens WHERE client_id = ? AND dossier_id = ? AND revoked = 0", []any{clientID, dossierID}, &tokens); err != nil {
|
|
return nil, err
|
|
}
|
|
for _, t := range tokens {
|
|
t.Revoked = 1
|
|
authSave("oauth_refresh_tokens", t)
|
|
}
|
|
return OAuthRefreshTokenCreate(clientID, dossierID)
|
|
}
|
|
|
|
// OAuthCleanup removes expired codes and tokens (call periodically)
|
|
func OAuthCleanup() error {
|
|
now := time.Now().Unix()
|
|
|
|
// Delete expired or used codes (keep used codes 1 hour for debugging)
|
|
var codes []*OAuthCode
|
|
if err := authQuery("SELECT * FROM oauth_codes WHERE expires_at < ? OR used = 1", []any{now - 3600}, &codes); err == nil {
|
|
for _, c := range codes {
|
|
authDelete("oauth_codes", "code", c.Code)
|
|
}
|
|
}
|
|
|
|
// Delete expired revoked tokens (keep revoked 1 day)
|
|
var tokens []*OAuthRefreshToken
|
|
if err := authQuery("SELECT * FROM oauth_refresh_tokens WHERE expires_at < ? AND revoked = 1", []any{now - 86400}, &tokens); err == nil {
|
|
for _, t := range tokens {
|
|
authDelete("oauth_refresh_tokens", "token_id", t.TokenID)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ============================================================================
|
|
// Reference Database Queries (lab_test, lab_reference)
|
|
// ============================================================================
|
|
|
|
// refQuery queries the reference database (read-only reference data)
|
|
func refQuery(query string, args []any, slicePtr any) error {
|
|
start := time.Now()
|
|
defer func() { logSlowQuery(query, time.Since(start), args...) }()
|
|
|
|
sliceVal := reflect.ValueOf(slicePtr)
|
|
if sliceVal.Kind() != reflect.Ptr || sliceVal.Elem().Kind() != reflect.Slice {
|
|
return fmt.Errorf("Query requires a pointer to slice")
|
|
}
|
|
|
|
sliceType := sliceVal.Elem().Type()
|
|
elemType := sliceType.Elem()
|
|
if elemType.Kind() == reflect.Ptr {
|
|
elemType = elemType.Elem()
|
|
}
|
|
|
|
// Get struct field info
|
|
sample := reflect.New(elemType).Interface()
|
|
info, err := getTableInfo("", sample)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Build column->field mapping
|
|
colToField := make(map[string]*fieldInfo)
|
|
for i := range info.Fields {
|
|
colToField[info.Fields[i].Column] = &info.Fields[i]
|
|
}
|
|
|
|
rows, err := refDB.Query(query, args...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
|
|
cols, err := rows.Columns()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
result := reflect.MakeSlice(sliceType, 0, 0)
|
|
|
|
for rows.Next() {
|
|
item := reflect.New(elemType)
|
|
|
|
// Prepare scan destinations
|
|
scanDest := make([]any, len(cols))
|
|
for i, col := range cols {
|
|
fi := colToField[col]
|
|
if fi == nil {
|
|
var dummy any
|
|
scanDest[i] = &dummy
|
|
continue
|
|
}
|
|
|
|
switch fi.Type.Kind() {
|
|
case reflect.String:
|
|
var s sql.NullString
|
|
scanDest[i] = &s
|
|
case reflect.Int, reflect.Int64:
|
|
var n sql.NullInt64
|
|
scanDest[i] = &n
|
|
case reflect.Bool:
|
|
var b sql.NullBool
|
|
scanDest[i] = &b
|
|
default:
|
|
var dummy any
|
|
scanDest[i] = &dummy
|
|
}
|
|
}
|
|
|
|
if err := rows.Scan(scanDest...); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Map values to struct fields
|
|
for i, col := range cols {
|
|
fi := colToField[col]
|
|
if fi == nil {
|
|
continue
|
|
}
|
|
|
|
field := item.Elem().Field(fi.Index)
|
|
switch fi.Type.Kind() {
|
|
case reflect.String:
|
|
ns := scanDest[i].(*sql.NullString)
|
|
if ns.Valid {
|
|
field.SetString(ns.String)
|
|
}
|
|
case reflect.Int, reflect.Int64:
|
|
ni := scanDest[i].(*sql.NullInt64)
|
|
if ni.Valid {
|
|
field.SetInt(ni.Int64)
|
|
}
|
|
case reflect.Bool:
|
|
nb := scanDest[i].(*sql.NullBool)
|
|
if nb.Valid {
|
|
field.SetBool(nb.Bool)
|
|
}
|
|
}
|
|
}
|
|
|
|
if sliceType.Elem().Kind() == reflect.Ptr {
|
|
result = reflect.Append(result, item)
|
|
} else {
|
|
result = reflect.Append(result, item.Elem())
|
|
}
|
|
}
|
|
|
|
sliceVal.Elem().Set(result)
|
|
return rows.Err()
|
|
}
|
|
|
|
// refSave saves to reference database (for import tools)
|
|
func refSave(table string, v any) error {
|
|
val := reflect.ValueOf(v)
|
|
if val.Kind() == reflect.Ptr {
|
|
val = val.Elem()
|
|
}
|
|
|
|
// Handle slice
|
|
if val.Kind() == reflect.Slice {
|
|
for i := 0; i < val.Len(); i++ {
|
|
item := val.Index(i)
|
|
if item.Kind() == reflect.Ptr {
|
|
item = item.Elem()
|
|
}
|
|
if err := refSave(table, item.Addr().Interface()); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Single struct
|
|
info, err := getTableInfo(table, v)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var cols []string
|
|
var vals []any
|
|
var placeholders []string
|
|
|
|
for _, fi := range info.Fields {
|
|
field := val.Field(fi.Index)
|
|
|
|
cols = append(cols, fi.Column)
|
|
placeholders = append(placeholders, "?")
|
|
|
|
switch fi.Type.Kind() {
|
|
case reflect.String:
|
|
vals = append(vals, field.String())
|
|
case reflect.Int, reflect.Int64:
|
|
vals = append(vals, field.Int())
|
|
case reflect.Bool:
|
|
v := 0
|
|
if field.Bool() {
|
|
v = 1
|
|
}
|
|
vals = append(vals, v)
|
|
default:
|
|
vals = append(vals, nil)
|
|
}
|
|
}
|
|
|
|
query := fmt.Sprintf("INSERT OR REPLACE INTO %s (%s) VALUES (%s)",
|
|
table, strings.Join(cols, ", "), strings.Join(placeholders, ", "))
|
|
|
|
start := time.Now()
|
|
defer func() { logSlowQuery(query, time.Since(start), vals...) }()
|
|
|
|
_, err = refDB.Exec(query, vals...)
|
|
return err
|
|
}
|
|
|
|
// refDelete deletes from reference database
|
|
func refDelete(table, pkCol, pkVal string) error {
|
|
query := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", table, pkCol)
|
|
start := time.Now()
|
|
defer func() { logSlowQuery(query, time.Since(start), pkVal) }()
|
|
|
|
_, err := refDB.Exec(query, pkVal)
|
|
return err
|
|
}
|