package lib

// ============================================================================
// ⛔ CRITICAL: DO NOT MODIFY THIS FILE WITHOUT JOHAN'S EXPRESS CONSENT
// ============================================================================
// This is the ONLY file allowed to access the database directly.
// Internal DB functions (unexported): dbSave, dbLoad, dbQuery, dbDelete, dbCount
// External code must use RBAC-checked functions (EntryWrite, DossierGet, etc.)
//
// Run `make check-db` to verify no direct DB access exists elsewhere.
// ============================================================================

import (
	"crypto/rand"
	"crypto/sha256"
	"database/sql"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"reflect"
	"strings"
	"time"
)

// Store provides a generic data layer with automatic encryption.
// String and []byte fields are encrypted automatically, except string
// columns whose name ends in "_id" and a small allow-list of plain-text
// columns (see encryptField). Other field kinds pass through unchanged.
//
// Unexported struct fields are ignored by the mapping.
//
// Struct tags:
//
//	db:"column_name"    - maps field to column (default: lowercase field name)
//	db:"column_name,pk" - marks as primary key
//	db:"-"              - skip this field
//
// Example:
//
//	type Entry struct {
//		EntryID int64  `db:"entry_id,pk"`
//		Type    string `db:"type"`   // auto-encrypted
//		Status  int    `db:"status"` // not encrypted
//	}

// fieldInfo holds metadata about a struct field.
type fieldInfo struct {
	Name   string       // struct field name
	Column string       // db column name
	Type   reflect.Type // Go type
	IsPK   bool         // is primary key
	Skip   bool         // not set by getTableInfo (db:"-" fields are omitted instead)
	Index  int          // field index in struct
}

// tableInfo holds metadata about a struct/table mapping.
type tableInfo struct {
	Name   string      // table name ("" for ad-hoc queries)
	Fields []fieldInfo // mapped fields, in struct declaration order
	PK     *fieldInfo  // points into Fields; nil when no field is tagged pk
}

// getTableInfo extracts table metadata from a struct (or pointer to struct)
// using reflection.
//
// table is the table name. Pass "" for ad-hoc queries, where a primary key
// is not required; otherwise a field tagged `db:"col,pk"` must exist. If
// several fields are tagged pk, the last one wins.
//
// Fields tagged db:"-" are omitted. Unexported fields are omitted too,
// because reflection can neither read them via Interface() nor set them, so
// mapping them would panic when saving or loading.
func getTableInfo(table string, v any) (*tableInfo, error) {
	t := reflect.TypeOf(v)
	if t == nil {
		return nil, fmt.Errorf("expected struct, got nil")
	}
	if t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	if t.Kind() != reflect.Struct {
		return nil, fmt.Errorf("expected struct, got %s", t.Kind())
	}

	info := &tableInfo{Name: table}
	pkIdx := -1
	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)
		if field.PkgPath != "" { // unexported field
			continue
		}

		tag := field.Tag.Get("db")
		if tag == "-" {
			continue
		}

		fi := fieldInfo{
			Name:   field.Name,
			Column: strings.ToLower(field.Name), // default column name
			Type:   field.Type,
			Index:  i,
		}

		// Parse tag options (override column name if specified)
		if tag != "" {
			parts := strings.Split(tag, ",")
			if parts[0] != "" {
				fi.Column = parts[0]
			}
			for _, opt := range parts[1:] {
				if opt == "pk" {
					fi.IsPK = true
				}
			}
		}

		info.Fields = append(info.Fields, fi)
		if fi.IsPK {
			pkIdx = len(info.Fields) - 1
		}
	}

	// Take the PK pointer only once Fields has stopped growing: a pointer taken
	// mid-loop could refer to a backing array that a later append replaced.
	if pkIdx >= 0 {
		info.PK = &info.Fields[pkIdx]
	}

	// PK required for table operations, optional for QuerySQL (table="")
	if table != "" && info.PK == nil {
		return nil, fmt.Errorf("no primary key defined for table %s (use `db:\"col,pk\"`)", table)
	}
	return info, nil
}

// goTypeToSQLite maps a Go type to the SQLite column type used in schema
// suggestions. All integer kinds and bool map to INTEGER, floats to REAL,
// strings to TEXT and []byte to BLOB. Everything else (including other
// slices) maps to TEXT.
func goTypeToSQLite(t reflect.Type) string {
	switch t.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return "INTEGER"
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return "INTEGER"
	case reflect.Float32, reflect.Float64:
		return "REAL"
	case reflect.Bool:
		return "INTEGER"
	case reflect.String:
		return "TEXT"
	case reflect.Slice:
		if t.Elem().Kind() == reflect.Uint8 {
			return "BLOB"
		}
		return "TEXT" // JSON-encode other slices
	default:
		return "TEXT"
	}
}

// Verify checks that the database schema matches the struct definition.
// It reads the live columns with SQLite's PRAGMA table_info. It returns nil
// when every mapped column exists. Otherwise it returns an error containing
// a suggested CREATE TABLE (when the table is missing) or ALTER TABLE ... ADD
// COLUMN statements (when columns are missing). Extra database columns and
// column type differences are not reported.
func Verify(table string, v any) error { info, err := getTableInfo(table, v) if err != nil { return err } // Get current columns from DB rows, err := db.Query(fmt.Sprintf("PRAGMA table_info(%s)", table)) if err != nil { return fmt.Errorf("failed to query table info: %w", err) } defer rows.Close() dbColumns := make(map[string]string) // column -> type for rows.Next() { var cid int var name, colType string var notNull, pk int var dflt sql.NullString if err := rows.Scan(&cid, &name, &colType, ¬Null, &dflt, &pk); err != nil { return err } dbColumns[name] = colType } // Check if table exists if len(dbColumns) == 0 { // Generate CREATE TABLE var cols []string for _, f := range info.Fields { colDef := fmt.Sprintf("%s %s", f.Column, goTypeToSQLite(f.Type)) if f.IsPK { colDef += " PRIMARY KEY" } cols = append(cols, colDef) } return fmt.Errorf(`schema mismatch: table '%s' does not exist Suggested fix: CREATE TABLE %s ( %s ); `, table, table, strings.Join(cols, ",\n ")) } // Check for missing columns var missing []fieldInfo for _, f := range info.Fields { if _, exists := dbColumns[f.Column]; !exists { missing = append(missing, f) } } if len(missing) > 0 { var alters []string var details []string for _, f := range missing { sqlType := goTypeToSQLite(f.Type) details = append(details, fmt.Sprintf(" - %s (%s)", f.Column, sqlType)) alters = append(alters, fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s;", table, f.Column, sqlType)) } return fmt.Errorf(`schema mismatch for table '%s' Missing columns: %s Suggested fix: %s Run these manually, then restart. 
`, table, strings.Join(details, "\n"), strings.Join(alters, "\n ")) } return nil } // VerifyAll checks multiple table/struct pairs func VerifyAll(pairs ...any) error { if len(pairs)%2 != 0 { return fmt.Errorf("VerifyAll requires pairs of (table string, struct)") } for i := 0; i < len(pairs); i += 2 { table, ok := pairs[i].(string) if !ok { return fmt.Errorf("expected string for table name at position %d", i) } if err := Verify(table, pairs[i+1]); err != nil { return err } } return nil } // dbSave upserts struct(s) to the database. // Accepts a single struct or a slice of structs. // String and []byte fields are encrypted automatically. // Slices are wrapped in a transaction for atomicity. func dbSave(table string, v any) error { start := time.Now() defer func() { logSlowQuery("INSERT OR REPLACE INTO "+table, time.Since(start)) }() val := reflect.ValueOf(v) if val.Kind() == reflect.Ptr { val = val.Elem() } // Slice: bulk upsert with transaction if val.Kind() == reflect.Slice { if val.Len() == 0 { return nil } // Get table info from first element elemType := val.Type().Elem() if elemType.Kind() == reflect.Ptr { elemType = elemType.Elem() } sample := reflect.New(elemType).Interface() info, err := getTableInfo(table, sample) if err != nil { return err } // Build query once var columns []string var placeholders []string for _, f := range info.Fields { columns = append(columns, f.Column) placeholders = append(placeholders, "?") } query := fmt.Sprintf( "INSERT OR REPLACE INTO %s (%s) VALUES (%s)", table, strings.Join(columns, ", "), strings.Join(placeholders, ", "), ) // Transaction with prepared statement tx, err := db.Begin() if err != nil { return err } defer tx.Rollback() stmt, err := tx.Prepare(query) if err != nil { return err } defer stmt.Close() for i := 0; i < val.Len(); i++ { elem := val.Index(i) if elem.Kind() == reflect.Ptr { elem = elem.Elem() } values := make([]any, len(info.Fields)) for j, f := range info.Fields { values[j] = 
encryptField(elem.Field(f.Index), f.Column) } if _, err := stmt.Exec(values...); err != nil { return err } } return tx.Commit() } // Single struct info, err := getTableInfo(table, v) if err != nil { return err } var columns []string var placeholders []string var values []any for _, f := range info.Fields { columns = append(columns, f.Column) placeholders = append(placeholders, "?") values = append(values, encryptField(val.Field(f.Index), f.Column)) } query := fmt.Sprintf( "INSERT OR REPLACE INTO %s (%s) VALUES (%s)", table, strings.Join(columns, ", "), strings.Join(placeholders, ", "), ) _, err = db.Exec(query, values...) return err } // dbLoad retrieves a record by primary key and populates the struct. // String and []byte fields are decrypted automatically. func dbLoad(table string, id string, v any) error { start := time.Now() defer func() { logSlowQuery("SELECT FROM "+table+" WHERE pk=?", time.Since(start), id) }() info, err := getTableInfo(table, v) if err != nil { return err } val := reflect.ValueOf(v) if val.Kind() != reflect.Ptr { return fmt.Errorf("Load requires a pointer to struct") } val = val.Elem() // Build SELECT var columns []string for _, f := range info.Fields { columns = append(columns, f.Column) } query := fmt.Sprintf( "SELECT %s FROM %s WHERE %s = ?", strings.Join(columns, ", "), table, info.PK.Column, ) row := db.QueryRow(query, id) // Create scan destinations scanDest := make([]any, len(info.Fields)) for i, f := range info.Fields { scanDest[i] = createScanDest(f.Type) } if err := row.Scan(scanDest...); err != nil { return err } // Decrypt and set fields for i, f := range info.Fields { decryptAndSet(val.Field(f.Index), scanDest[i], f.Type, f.Column) } return nil } // dbQuery runs a SQL query and populates the slice. // Column names in result must match struct db tags. // String and []byte fields are decrypted automatically. 
func dbQuery(query string, args []any, slicePtr any) error { start := time.Now() defer func() { logSlowQuery(query, time.Since(start), args...) }() sliceVal := reflect.ValueOf(slicePtr) if sliceVal.Kind() != reflect.Ptr || sliceVal.Elem().Kind() != reflect.Slice { return fmt.Errorf("Query requires a pointer to slice") } sliceType := sliceVal.Elem().Type() elemType := sliceType.Elem() if elemType.Kind() == reflect.Ptr { elemType = elemType.Elem() } // Get struct field info sample := reflect.New(elemType).Interface() info, err := getTableInfo("", sample) if err != nil { return err } // Build column->field mapping colToField := make(map[string]*fieldInfo) for i := range info.Fields { colToField[info.Fields[i].Column] = &info.Fields[i] } rows, err := db.Query(query, args...) if err != nil { return err } defer rows.Close() // Get column names from result cols, err := rows.Columns() if err != nil { return err } result := reflect.MakeSlice(sliceType, 0, 0) for rows.Next() { elem := reflect.New(elemType).Elem() // Create scan destinations for each column scanDest := make([]any, len(cols)) fieldMap := make([]*fieldInfo, len(cols)) for i, col := range cols { if fi, ok := colToField[col]; ok { scanDest[i] = createScanDest(fi.Type) fieldMap[i] = fi } else { // Unknown column - scan into throwaway scanDest[i] = new(any) } } if err := rows.Scan(scanDest...); err != nil { return err } // Decrypt and set matched fields for i, fi := range fieldMap { if fi != nil { decryptAndSet(elem.Field(fi.Index), scanDest[i], fi.Type, fi.Column) } } if sliceType.Elem().Kind() == reflect.Ptr { result = reflect.Append(result, elem.Addr()) } else { result = reflect.Append(result, elem) } } sliceVal.Elem().Set(result) return nil } // dbCount runs a SELECT COUNT(*) query and returns the result. func dbCount(query string, args ...any) (int, error) { var count int err := db.QueryRow(query, args...).Scan(&count) return count, err } // dbDelete removes a record by primary key. 
// pkCol is the primary key column name, id is 16-char hex string. func dbDelete(table, pkCol, id string) error { query := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", table, pkCol) _, err := db.Exec(query, id) return err } // dbDeleteTree removes a record and all its descendants. // Traverses the parent-child hierarchy recursively, deletes children first. // Works with any SQL database (no CTEs or CASCADE needed). func dbDeleteTree(table, pkCol, parentCol, id string) error { // Collect all IDs (parent + descendants) var ids []string var collect func(string) error collect = func(pid string) error { ids = append(ids, pid) rows, err := db.Query( fmt.Sprintf("SELECT %s FROM %s WHERE %s = ?", pkCol, table, parentCol), pid, ) if err != nil { return err } defer rows.Close() for rows.Next() { var childID string if err := rows.Scan(&childID); err != nil { return err } if err := collect(childID); err != nil { return err } } return nil } if err := collect(id); err != nil { return err } // Delete in reverse order (children first) tx, err := db.Begin() if err != nil { return err } defer tx.Rollback() for i := len(ids) - 1; i >= 0; i-- { if _, err := tx.Exec( fmt.Sprintf("DELETE FROM %s WHERE %s = ?", table, pkCol), ids[i], ); err != nil { return err } } return tx.Commit() } // encryptField encrypts string and []byte fields, passes others through // ID columns (ending in _id) are NOT encrypted func encryptField(v reflect.Value, column string) any { switch v.Kind() { case reflect.String: s := v.String() if s == "" { return "" } // Don't encrypt ID columns or known plain-text columns plainCols := map[string]bool{"language": true, "timezone": true, "weight_unit": true, "height_unit": true} if strings.HasSuffix(column, "_id") || plainCols[column] { return s } return CryptoEncrypt(s) case reflect.Slice: if v.Type().Elem().Kind() == reflect.Uint8 { // []byte b := v.Bytes() if len(b) == 0 { return []byte{} } return CryptoEncryptBytes(b) } // Other slices: JSON encode then encrypt return 
v.Interface() default: return v.Interface() } } // createScanDest creates an appropriate scan destination for a Go type func createScanDest(t reflect.Type) any { switch t.Kind() { case reflect.Int, reflect.Int64: return new(sql.NullInt64) case reflect.String: return new(sql.NullString) case reflect.Slice: if t.Elem().Kind() == reflect.Uint8 { return new([]byte) } return new(sql.NullString) case reflect.Bool: return new(sql.NullInt64) default: return new(sql.NullString) } } // decryptAndSet decrypts the scanned value and sets the struct field // ID columns (ending in _id) are NOT decrypted func decryptAndSet(field reflect.Value, scanned any, t reflect.Type, column string) { switch t.Kind() { case reflect.Int, reflect.Int64: if ns, ok := scanned.(*sql.NullInt64); ok && ns.Valid { field.SetInt(ns.Int64) } case reflect.String: if ns, ok := scanned.(*sql.NullString); ok && ns.Valid { // Don't decrypt ID columns or known plain-text columns plainCols := map[string]bool{"language": true, "timezone": true, "weight_unit": true, "height_unit": true} if strings.HasSuffix(column, "_id") || plainCols[column] { field.SetString(ns.String) } else { field.SetString(CryptoDecrypt(ns.String)) } } case reflect.Slice: if t.Elem().Kind() == reflect.Uint8 { // []byte - decrypt if b, ok := scanned.(*[]byte); ok && b != nil && len(*b) > 0 { decrypted, err := CryptoDecryptBytes(*b) if err == nil { field.SetBytes(decrypted) } } } case reflect.Bool: if ns, ok := scanned.(*sql.NullInt64); ok && ns.Valid { field.SetBool(ns.Int64 != 0) } } } // ============================================================================= // OAuth 2.0 Store Functions // ============================================================================= // OAuthClient represents a registered OAuth client (Claude, Flutter, etc.) 
type OAuthClient struct { ClientID string `db:"client_id,pk"` ClientSecret string `db:"client_secret"` // hashed Name string `db:"name"` RedirectURIs []string `db:"-"` // parsed from JSON RedirectJSON string `db:"redirect_uris"` CreatedAt int64 `db:"created_at"` } // OAuthCode represents a temporary authorization code type OAuthCode struct { Code string `db:"code,pk"` ClientID string `db:"client_id"` DossierID string `db:"dossier_id"` RedirectURI string `db:"redirect_uri"` CodeChallenge string `db:"code_challenge"` CodeChallengeMethod string `db:"code_challenge_method"` ExpiresAt int64 `db:"expires_at"` Used int `db:"used"` } // OAuthRefreshToken represents a long-lived refresh token type OAuthRefreshToken struct { TokenID string `db:"token_id,pk"` ClientID string `db:"client_id"` DossierID string `db:"dossier_id"` ExpiresAt int64 `db:"expires_at"` Revoked int `db:"revoked"` CreatedAt int64 `db:"created_at"` } // generateID creates a random hex ID of specified byte length func generateID(bytes int) string { b := make([]byte, bytes) rand.Read(b) return hex.EncodeToString(b) } // hashSecret hashes a client secret for storage func hashSecret(secret string) string { h := sha256.Sum256([]byte(secret)) return hex.EncodeToString(h[:]) } // OAuthClientCreate creates a new OAuth client and returns the plain secret func OAuthClientCreate(name string, redirectURIs []string) (*OAuthClient, string, error) { clientID := generateID(16) plainSecret := generateID(32) urisJSON, _ := json.Marshal(redirectURIs) client := &OAuthClient{ ClientID: clientID, ClientSecret: hashSecret(plainSecret), Name: name, RedirectURIs: redirectURIs, RedirectJSON: string(urisJSON), CreatedAt: time.Now().Unix(), } if err := authSave("oauth_clients", client); err != nil { return nil, "", err } return client, plainSecret, nil } // OAuthClientCreatePublic creates a public OAuth client with a fixed client_id (no secret) // Used for first-party clients like the MCP bridge func 
OAuthClientCreatePublic(clientID, name string) error { urisJSON, _ := json.Marshal([]string{}) client := &OAuthClient{ ClientID: clientID, ClientSecret: "", // Public client - no secret Name: name, RedirectURIs: []string{}, RedirectJSON: string(urisJSON), CreatedAt: time.Now().Unix(), } return authSave("oauth_clients", client) } // OAuthClientGet retrieves a client by ID func OAuthClientGet(clientID string) (*OAuthClient, error) { var client OAuthClient if err := authLoad("oauth_clients", clientID, &client); err != nil { return nil, err } // Parse redirect URIs json.Unmarshal([]byte(client.RedirectJSON), &client.RedirectURIs) return &client, nil } // OAuthClientVerifySecret checks if the provided secret matches func OAuthClientVerifySecret(client *OAuthClient, secret string) bool { return client.ClientSecret == hashSecret(secret) } // OAuthClientValidRedirectURI checks if URI is registered for client func OAuthClientValidRedirectURI(client *OAuthClient, uri string) bool { for _, u := range client.RedirectURIs { if u == uri { return true } } return false } // OAuthCodeCreate creates a new authorization code (valid for 10 minutes) func OAuthCodeCreate(clientID, dossierID, redirectURI, codeChallenge, codeChallengeMethod string) (*OAuthCode, error) { code := &OAuthCode{ Code: generateID(32), ClientID: clientID, DossierID: dossierID, RedirectURI: redirectURI, CodeChallenge: codeChallenge, CodeChallengeMethod: codeChallengeMethod, ExpiresAt: time.Now().Unix() + 600, // 10 minutes Used: 0, } if err := authSave("oauth_codes", code); err != nil { return nil, err } return code, nil } // OAuthCodeGet retrieves and validates a code (checks expiry, not used) func OAuthCodeGet(code string) (*OAuthCode, error) { var c OAuthCode if err := authLoad("oauth_codes", code, &c); err != nil { return nil, err } if c.Used != 0 { return nil, fmt.Errorf("code already used") } if time.Now().Unix() > c.ExpiresAt { return nil, fmt.Errorf("code expired") } return &c, nil } // OAuthCodeUse marks 
a code as used (single-use) func OAuthCodeUse(code string) error { var c OAuthCode if err := authLoad("oauth_codes", code, &c); err != nil { return err } c.Used = 1 return authSave("oauth_codes", &c) } // OAuthCodeVerifyPKCE verifies the PKCE code_verifier against stored challenge func OAuthCodeVerifyPKCE(c *OAuthCode, codeVerifier string) bool { if c.CodeChallenge == "" { return true // No PKCE required } if c.CodeChallengeMethod != "S256" { return false // Only support S256 } // S256: BASE64URL(SHA256(code_verifier)) == code_challenge h := sha256.Sum256([]byte(codeVerifier)) computed := base64.RawURLEncoding.EncodeToString(h[:]) return computed == c.CodeChallenge } // OAuthRefreshTokenCreate creates a new refresh token (valid for 30 days) func OAuthRefreshTokenCreate(clientID, dossierID string) (*OAuthRefreshToken, error) { token := &OAuthRefreshToken{ TokenID: generateID(32), ClientID: clientID, DossierID: dossierID, ExpiresAt: time.Now().Unix() + 30*24*60*60, // 30 days Revoked: 0, CreatedAt: time.Now().Unix(), } if err := authSave("oauth_refresh_tokens", token); err != nil { return nil, err } return token, nil } // OAuthRefreshTokenGet retrieves and validates a refresh token func OAuthRefreshTokenGet(tokenID string) (*OAuthRefreshToken, error) { var t OAuthRefreshToken if err := authLoad("oauth_refresh_tokens", tokenID, &t); err != nil { return nil, err } if t.Revoked != 0 { return nil, fmt.Errorf("token revoked") } if time.Now().Unix() > t.ExpiresAt { return nil, fmt.Errorf("token expired") } return &t, nil } // OAuthRefreshTokenRevoke revokes a refresh token func OAuthRefreshTokenRevoke(tokenID string) error { var t OAuthRefreshToken if err := authLoad("oauth_refresh_tokens", tokenID, &t); err != nil { return err } t.Revoked = 1 return authSave("oauth_refresh_tokens", &t) } // OAuthRefreshTokenRotate revokes old token and creates new one func OAuthRefreshTokenRotate(oldTokenID string) (*OAuthRefreshToken, error) { old, err := OAuthRefreshTokenGet(oldTokenID) 
if err != nil { return nil, err } // Revoke old if err := OAuthRefreshTokenRevoke(oldTokenID); err != nil { return nil, err } // Create new return OAuthRefreshTokenCreate(old.ClientID, old.DossierID) } // OAuthRefreshTokenRevokeAll revokes all refresh tokens for a dossier func OAuthRefreshTokenRevokeAll(dossierID string) error { var tokens []*OAuthRefreshToken if err := authQuery("SELECT * FROM oauth_refresh_tokens WHERE dossier_id = ? AND revoked = 0", []any{dossierID}, &tokens); err != nil { return err } for _, t := range tokens { t.Revoked = 1 if err := authSave("oauth_refresh_tokens", t); err != nil { return err } } return nil } // OAuthRefreshTokenGetForClient gets an existing valid refresh token for a client/dossier pair func OAuthRefreshTokenGetForClient(clientID, dossierID string) (*OAuthRefreshToken, error) { var tokens []*OAuthRefreshToken now := time.Now().Unix() if err := authQuery("SELECT * FROM oauth_refresh_tokens WHERE client_id = ? AND dossier_id = ? AND revoked = 0 AND expires_at > ? ORDER BY created_at DESC LIMIT 1", []any{clientID, dossierID, now}, &tokens); err != nil { return nil, err } if len(tokens) == 0 { return nil, fmt.Errorf("no token found") } return tokens[0], nil } // OAuthRefreshTokenGetOrCreate gets or creates a refresh token for a client/dossier pair func OAuthRefreshTokenGetOrCreate(clientID, dossierID string) (*OAuthRefreshToken, error) { token, err := OAuthRefreshTokenGetForClient(clientID, dossierID) if err == nil { return token, nil } return OAuthRefreshTokenCreate(clientID, dossierID) } // OAuthRefreshTokenRegenerate revokes all existing tokens for a client/dossier and creates new one func OAuthRefreshTokenRegenerate(clientID, dossierID string) (*OAuthRefreshToken, error) { var tokens []*OAuthRefreshToken if err := authQuery("SELECT * FROM oauth_refresh_tokens WHERE client_id = ? AND dossier_id = ? 
AND revoked = 0", []any{clientID, dossierID}, &tokens); err != nil { return nil, err } for _, t := range tokens { t.Revoked = 1 authSave("oauth_refresh_tokens", t) } return OAuthRefreshTokenCreate(clientID, dossierID) } // OAuthCleanup removes expired codes and tokens (call periodically) func OAuthCleanup() error { now := time.Now().Unix() // Delete expired or used codes (keep used codes 1 hour for debugging) var codes []*OAuthCode if err := authQuery("SELECT * FROM oauth_codes WHERE expires_at < ? OR used = 1", []any{now - 3600}, &codes); err == nil { for _, c := range codes { authDelete("oauth_codes", "code", c.Code) } } // Delete expired revoked tokens (keep revoked 1 day) var tokens []*OAuthRefreshToken if err := authQuery("SELECT * FROM oauth_refresh_tokens WHERE expires_at < ? AND revoked = 1", []any{now - 86400}, &tokens); err == nil { for _, t := range tokens { authDelete("oauth_refresh_tokens", "token_id", t.TokenID) } } return nil } // ============================================================================ // Reference Database Queries (lab_test, lab_reference) // ============================================================================ // RefQuery queries the reference database (read-only reference data) func RefQuery(query string, args []any, slicePtr any) error { start := time.Now() defer func() { logSlowQuery(query, time.Since(start), args...) 
}() sliceVal := reflect.ValueOf(slicePtr) if sliceVal.Kind() != reflect.Ptr || sliceVal.Elem().Kind() != reflect.Slice { return fmt.Errorf("Query requires a pointer to slice") } sliceType := sliceVal.Elem().Type() elemType := sliceType.Elem() if elemType.Kind() == reflect.Ptr { elemType = elemType.Elem() } // Get struct field info sample := reflect.New(elemType).Interface() info, err := getTableInfo("", sample) if err != nil { return err } // Build column->field mapping colToField := make(map[string]*fieldInfo) for i := range info.Fields { colToField[info.Fields[i].Column] = &info.Fields[i] } rows, err := refDB.Query(query, args...) if err != nil { return err } defer rows.Close() cols, err := rows.Columns() if err != nil { return err } result := reflect.MakeSlice(sliceType, 0, 0) for rows.Next() { item := reflect.New(elemType) // Prepare scan destinations scanDest := make([]any, len(cols)) for i, col := range cols { fi := colToField[col] if fi == nil { var dummy any scanDest[i] = &dummy continue } switch fi.Type.Kind() { case reflect.String: var s sql.NullString scanDest[i] = &s case reflect.Int, reflect.Int64: var n sql.NullInt64 scanDest[i] = &n case reflect.Bool: var b sql.NullBool scanDest[i] = &b default: var dummy any scanDest[i] = &dummy } } if err := rows.Scan(scanDest...); err != nil { return err } // Map values to struct fields for i, col := range cols { fi := colToField[col] if fi == nil { continue } field := item.Elem().Field(fi.Index) switch fi.Type.Kind() { case reflect.String: ns := scanDest[i].(*sql.NullString) if ns.Valid { field.SetString(ns.String) } case reflect.Int, reflect.Int64: ni := scanDest[i].(*sql.NullInt64) if ni.Valid { field.SetInt(ni.Int64) } case reflect.Bool: nb := scanDest[i].(*sql.NullBool) if nb.Valid { field.SetBool(nb.Bool) } } } if sliceType.Elem().Kind() == reflect.Ptr { result = reflect.Append(result, item) } else { result = reflect.Append(result, item.Elem()) } } sliceVal.Elem().Set(result) return rows.Err() }