diff --git a/api/api_entries.go b/api/api_entries.go index 349a4e0..ef49228 100644 --- a/api/api_entries.go +++ b/api/api_entries.go @@ -85,7 +85,7 @@ func handleEntries(w http.ResponseWriter, r *http.Request) { http.Error(w, "unknown category: "+req.Category, http.StatusBadRequest) return } - if err := lib.EntryDeleteByCategory(dossierID, catInt); err != nil { + if err := lib.EntryDeleteByCategory(ctx, dossierID, catInt); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -107,7 +107,7 @@ func handleEntries(w http.ResponseWriter, r *http.Request) { if req.Delete && req.ID != "" { entryID := req.ID if req.DeleteChildren { - if err := lib.EntryDeleteTree(req.Dossier, entryID); err != nil { + if err := lib.EntryDeleteTree(ctx, req.Dossier, entryID); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } diff --git a/api/api_v1.go b/api/api_v1.go index 94e8b5c..c6a6780 100644 --- a/api/api_v1.go +++ b/api/api_v1.go @@ -162,7 +162,7 @@ func v1Dossiers(w http.ResponseWriter, r *http.Request) { } // Get available categories for this dossier - categories := getDossierCategories(tid) + categories := getDossierCategories(&lib.AccessContext{AccessorID: authID}, tid) result = append(result, map[string]any{ "id": d.DossierID, @@ -176,23 +176,14 @@ func v1Dossiers(w http.ResponseWriter, r *http.Request) { v1JSON(w, result) } -func getDossierCategories(dossierID string) []string { - // Query distinct categories for this dossier - var counts []struct { - Category int `db:"category"` - Count int `db:"cnt"` +func getDossierCategories(ctx *lib.AccessContext, dossierID string) []string { + counts, err := lib.EntryCategoryCounts(ctx, dossierID) + if err != nil { + return []string{} } - lib.Query("SELECT category, COUNT(*) as cnt FROM entries WHERE dossier_id = ? AND category > 0 GROUP BY category", []any{dossierID}, &counts) - - categories := []string{} // Empty slice, not nil - for _, c := range counts { - if c.Count > 0 { - // Use lib.CategoryName to get proper name for all categories - name := lib.CategoryName(c.Category) - if name != "unknown" { - categories = append(categories, name) - } - } + categories := []string{} + for name := range counts { + categories = append(categories, name) } return categories } diff --git a/cmd/import-lab/main.go b/cmd/import-lab/main.go index 924bae2..042322d 100644 --- a/cmd/import-lab/main.go +++ b/cmd/import-lab/main.go @@ -284,8 +284,8 @@ func main() { // Save fmt.Printf("Saving %d entries...\n", len(entries)) start := time.Now() - if err := lib.Save("entries", entries); err != nil { - fmt.Printf("lib.Save failed: %v\n", err) + if err := lib.EntryAddBatchValues(entries); err != nil { + fmt.Printf("EntryAddBatchValues failed: %v\n", err) os.Exit(1) } fmt.Printf("Done in %v: %d orders (%d created, %d updated), %d total entries\n", @@ -462,7 +462,7 @@ func patchLocalTime(dossierID, inputPath string) { } if e, ok := byKey[order.sourceKey]; ok { if patchDataLocalTime(e, order.localTime) { - if err := lib.Save("entries", []lib.Entry{*e}); err == nil { + if err := lib.EntryAddBatchValues([]lib.Entry{*e}); err == nil { patched++ } } diff --git a/cmd/populate-search-key/main.go b/cmd/populate-search-key/main.go index 0c6a739..a900491 100644 --- a/cmd/populate-search-key/main.go +++ b/cmd/populate-search-key/main.go @@ -12,15 +12,21 @@ func main() { } lib.ConfigInit() - // Get all dossiers with lab entries + // Get all dossiers + allDossiers, err := lib.DossierList(nil, nil) // nil ctx = system, nil filter = all + if err != nil { + log.Fatal("List dossiers:", err) + } type dossierRow struct { - DossierID string `db:"dossier_id"` - Count int `db:"count"` + DossierID string + Count int } var dossiers []dossierRow - if err := lib.Query("SELECT dossier_id, COUNT(*) as count FROM entries WHERE category = 3 GROUP BY dossier_id", - []any{}, &dossiers); err != nil { - log.Fatal("Query dossiers:", err) + for _, d := range allDossiers { + count, _ := lib.EntryCount(nil, d.DossierID, lib.CategoryLab, "") + if count > 0 { + dossiers = append(dossiers, dossierRow{DossierID: d.DossierID, Count: count}) + } } fmt.Printf("Found %d dossiers with lab data\n", len(dossiers)) diff --git a/doc-processor/restore/create_events.go b/doc-processor/restore/create_events.go index b6228e6..1197d21 100644 --- a/doc-processor/restore/create_events.go +++ b/doc-processor/restore/create_events.go @@ -50,12 +50,7 @@ func main() { } // Get all document entries for Anastasiia - var docs []lib.Entry - err := lib.Query( - "SELECT entry_id, value, data, timestamp FROM entries WHERE dossier_id = ? AND category = ?", - []interface{}{dossierID, lib.CategoryDocument}, - &docs, - ) + docs, err := lib.EntryQuery(dossierID, lib.CategoryDocument, "") if err != nil { fmt.Fprintf(os.Stderr, "Query: %v\n", err) os.Exit(1) @@ -144,7 +139,7 @@ func createEventEntry(sourceID string, event map[string]interface{}) error { Data: string(dataJSON), } - return lib.Save("entries", &entry) + return lib.EntryAdd(&entry) } func createAssessmentEntry(sourceID string, assessment map[string]interface{}) error { @@ -170,27 +165,29 @@ func createAssessmentEntry(sourceID string, assessment map[string]interface{}) e Data: string(dataJSON), } - return lib.Save("entries", &entry) + return lib.EntryAdd(&entry) } func deleteEvents() { fmt.Println("Deleting entries with tag:", eventTag) - var entries []lib.Entry - err := lib.Query( - "SELECT entry_id FROM entries WHERE tags LIKE ?", - []interface{}{"%" + eventTag + "%"}, - &entries, - ) + entries, err := lib.EntryQuery(dossierID, -1, "") if err != nil { fmt.Fprintf(os.Stderr, "Query: %v\n", err) os.Exit(1) } - fmt.Printf("Found %d entries to delete\n", len(entries)) - + var toDelete []lib.Entry for _, e := range entries { - if err := lib.Delete("entries", "entry_id", e.EntryID); err != nil { + if e.Tags != "" && e.Tags == eventTag { + toDelete = append(toDelete, *e) + } + } + + fmt.Printf("Found %d entries to delete\n", len(toDelete)) + + for _, e := range toDelete { + if err := lib.EntryDelete(e.EntryID); err != nil { fmt.Fprintf(os.Stderr, "Delete %s: %v\n", e.EntryID, err) } } diff --git a/doc-processor/restore/import_docs.go b/doc-processor/restore/import_docs.go index c545324..9796570 100644 --- a/doc-processor/restore/import_docs.go +++ b/doc-processor/restore/import_docs.go @@ -211,7 +211,7 @@ func processDocument(filePath, filename string) error { Data: string(dataJSON), } - if err := lib.Save("entries", &entry); err != nil { + if err := lib.EntryAdd(&entry); err != nil { return fmt.Errorf("save entry: %w", err) } @@ -221,17 +221,23 @@ func processDocument(filePath, filename string) error { func deleteImported() { fmt.Println("Deleting entries with tag:", batchTag) - var entries []lib.Entry - err := lib.Query("SELECT entry_id FROM entries WHERE tags LIKE ?", []interface{}{"%" + batchTag + "%"}, &entries) + entries, err := lib.EntryQuery(dossierID, -1, "") if err != nil { fmt.Fprintf(os.Stderr, "Query: %v\n", err) os.Exit(1) } - fmt.Printf("Found %d entries to delete\n", len(entries)) - + var toDelete []*lib.Entry for _, e := range entries { - if err := lib.Delete("entries", "entry_id", e.EntryID); err != nil { + if strings.Contains(e.Tags, batchTag) { + toDelete = append(toDelete, e) + } + } + + fmt.Printf("Found %d entries to delete\n", len(toDelete)) + + for _, e := range toDelete { + if err := lib.EntryDelete(e.EntryID); err != nil { fmt.Fprintf(os.Stderr, "Delete %s: %v\n", e.EntryID, err) } } diff --git a/find_dossiers/main.go b/find_dossiers/main.go index 48ee49e..0f0401a 100644 --- a/find_dossiers/main.go +++ b/find_dossiers/main.go @@ -40,8 +40,13 @@ func main() { fmt.Println("-- DOSSIERS") var dossiers []*lib.Dossier - lib.Query(`SELECT * FROM dossiers WHERE dossier_id IN (?, ?, ?, ?, ?)`, - []any{keepIDs[0], keepIDs[1], keepIDs[2], keepIDs[3], keepIDs[4]}, &dossiers) + for _, id := range keepIDs { + d, err := lib.DossierGet(nil, id) + if err != nil { + continue + } + dossiers = append(dossiers, d) + } for _, d := range dossiers { fmt.Printf("INSERT INTO dossiers (dossier_id, email_hash, email, name, date_of_birth, sex, phone, language, timezone, created_at, weight_unit, height_unit) VALUES ('%s', '%s', '%s', '%s', '%s', %d, '%s', '%s', '%s', %d, '%s', '%s');\n", d.DossierID, esc(d.EmailHash), esc(d.Email), esc(d.Name), esc(d.DateOfBirth), d.Sex, esc(d.Phone), esc(d.Language), esc(d.Timezone), d.CreatedAt, esc(d.WeightUnit), esc(d.HeightUnit)) @@ -49,11 +54,16 @@ func main() { fmt.Println("\n-- DOSSIER_ACCESS") var accesses []*lib.DossierAccess - lib.Query(`SELECT * FROM dossier_access - WHERE accessor_dossier_id IN (?, ?, ?, ?, ?) - AND target_dossier_id IN (?, ?, ?, ?, ?)`, - []any{keepIDs[0], keepIDs[1], keepIDs[2], keepIDs[3], keepIDs[4], - keepIDs[0], keepIDs[1], keepIDs[2], keepIDs[3], keepIDs[4]}, &accesses) + for _, id := range keepIDs { + list, _ := lib.AccessListByAccessor(id) + for _, a := range list { + for _, kid := range keepIDs { + if a.TargetDossierID == kid { + accesses = append(accesses, a) + } + } + } + } for _, a := range accesses { isCare := 0 if a.IsCareReceiver { isCare = 1 } diff --git a/import-genome/main.go b/import-genome/main.go index 627e79a..4de1229 100644 --- a/import-genome/main.go +++ b/import-genome/main.go @@ -60,7 +60,7 @@ EXAMPLE: DATABASE: SNPedia reference: ~/dev/inou/snpedia-genotypes/genotypes.db (read-only) - Entries: via lib.Save() to /tank/inou/data/inou.db + Entries: via lib.EntryAddBatchValues() to /tank/inou/data/inou.db VERSION: ` + version) } @@ -460,7 +460,7 @@ func main() { os.Exit(1) } - if err := lib.EntryDeleteByCategory(dossierID, lib.CategoryGenome); err != nil { + if err := lib.EntryDeleteByCategory(nil, dossierID, lib.CategoryGenome); err != nil { // nil ctx = system import fmt.Println("Delete existing failed:", err) os.Exit(1) } @@ -553,11 +553,11 @@ func main() { // ===== PHASE 8: Save to database ===== phase8Start := time.Now() - if err := lib.Save("entries", entries); err != nil { - fmt.Println("lib.Save failed:", err) + if err := lib.EntryAddBatchValues(entries); err != nil { + fmt.Println("EntryAddBatchValues failed:", err) os.Exit(1) } - fmt.Printf("Phase 8 - lib.Save: %v (%d entries saved)\n", time.Since(phase8Start), len(entries)) + fmt.Printf("Phase 8 - Save: %v (%d entries saved)\n", time.Since(phase8Start), len(entries)) fmt.Printf("\nTOTAL: %v\n", time.Since(totalStart)) fmt.Printf("Extraction ID: %s\n", extractionID) diff --git a/lib/access.go b/lib/access.go index 03e8268..91f852b 100644 --- a/lib/access.go +++ b/lib/access.go @@ -261,7 +261,7 @@ func accessGrantListRaw(f *PermissionFilter) ([]*Access, error) { q += " ORDER BY created_at DESC" var result []*Access - err := Query(q, args, &result) + err := dbQuery(q, args, &result) return result, err } @@ -312,7 +312,7 @@ func GrantAccess(dossierID, granteeID, entryID, ops string) error { Ops: ops, CreatedAt: time.Now().Unix(), } - err := Save("access", grant) + err := dbSave("access", grant) if err == nil { InvalidateCacheForAccessor(granteeID) } @@ -321,10 +321,10 @@ func GrantAccess(dossierID, granteeID, entryID, ops string) error { func RevokeAccess(accessID string) error { var grant Access - if err := Load("access", accessID, &grant); err != nil { + if err := dbLoad("access", accessID, &grant); err != nil { return err } - err := Delete("access", "access_id", accessID) + err := dbDelete("access", "access_id", accessID) if err == nil { InvalidateCacheForAccessor(grant.GranteeID) } diff --git a/lib/data.go b/lib/data.go index 45537a3..07d3bbb 100644 --- a/lib/data.go +++ b/lib/data.go @@ -97,19 +97,22 @@ func EntryDelete(entryID string) error { return EntryRemove(nil, entryID) // nil ctx = internal operation } -// EntryDeleteTree removes an entry and all its children -func EntryDeleteTree(dossierID, entryID string) error { +// EntryDeleteTree removes an entry and all its children. Requires delete permission. +func EntryDeleteTree(ctx *AccessContext, dossierID, entryID string) error { + if err := checkAccess(accessorIDFromContext(ctx), dossierID, entryID, 0, 'd'); err != nil { + return err + } // Delete children first var children []*Entry - if err := Query("SELECT entry_id FROM entries WHERE dossier_id = ? AND parent_id = ?", []any{dossierID, entryID}, &children); err != nil { + if err := dbQuery("SELECT entry_id FROM entries WHERE dossier_id = ? AND parent_id = ?", []any{dossierID, entryID}, &children); err != nil { return err } for _, c := range children { - if err := Delete("entries", "entry_id", c.EntryID); err != nil { + if err := dbDelete("entries", "entry_id", c.EntryID); err != nil { return err } } - return Delete("entries", "entry_id", entryID) + return dbDelete("entries", "entry_id", entryID) } // EntryModify updates an entry (internal operation) @@ -132,41 +135,41 @@ func EntryQuery(dossierID string, category int, typ string) ([]*Entry, error) { } q += " ORDER BY timestamp DESC" var result []*Entry - return result, Query(q, args, &result) + return result, dbQuery(q, args, &result) } // EntryQueryByDate retrieves entries within a timestamp range func EntryQueryByDate(dossierID string, from, to int64) ([]*Entry, error) { var result []*Entry - return result, Query("SELECT * FROM entries WHERE dossier_id = ? AND timestamp >= ? AND timestamp < ? ORDER BY timestamp DESC", + return result, dbQuery("SELECT * FROM entries WHERE dossier_id = ? AND timestamp >= ? AND timestamp < ? ORDER BY timestamp DESC", []any{dossierID, from, to}, &result) } // EntryChildren retrieves child entries ordered by ordinal func EntryChildren(dossierID, parentID string) ([]*Entry, error) { var result []*Entry - return result, Query("SELECT * FROM entries WHERE dossier_id = ? AND parent_id = ? ORDER BY ordinal", + return result, dbQuery("SELECT * FROM entries WHERE dossier_id = ? AND parent_id = ? ORDER BY ordinal", []any{dossierID, parentID}, &result) } // EntryChildrenByCategory retrieves child entries filtered by category, ordered by ordinal func EntryChildrenByCategory(dossierID, parentID string, category int) ([]*Entry, error) { var result []*Entry - return result, Query("SELECT * FROM entries WHERE dossier_id = ? AND parent_id = ? AND category = ? ORDER BY ordinal", + return result, dbQuery("SELECT * FROM entries WHERE dossier_id = ? AND parent_id = ? AND category = ? ORDER BY ordinal", []any{dossierID, parentID, category}, &result) } // EntryChildrenByType retrieves child entries filtered by type string, ordered by ordinal func EntryChildrenByType(dossierID, parentID string, typ string) ([]*Entry, error) { var result []*Entry - return result, Query("SELECT * FROM entries WHERE dossier_id = ? AND parent_id = ? AND type = ? ORDER BY ordinal", + return result, dbQuery("SELECT * FROM entries WHERE dossier_id = ? AND parent_id = ? AND type = ? ORDER BY ordinal", []any{dossierID, parentID, CryptoEncrypt(typ)}, &result) } // EntryRootByType finds the root entry (parent_id = 0 or NULL) for a given type func EntryRootByType(dossierID string, typ string) (*Entry, error) { var result []*Entry - err := Query("SELECT * FROM entries WHERE dossier_id = ? AND type = ? AND (parent_id IS NULL OR parent_id = '' OR parent_id = '0') LIMIT 1", + err := dbQuery("SELECT * FROM entries WHERE dossier_id = ? AND type = ? AND (parent_id IS NULL OR parent_id = '' OR parent_id = '0') LIMIT 1", []any{dossierID, CryptoEncrypt(typ)}, &result) if err != nil { return nil, err @@ -180,14 +183,14 @@ func EntryRootByType(dossierID string, typ string) (*Entry, error) { // EntryRootsByType finds all root entries (parent_id = '' or NULL) for a given type func EntryRootsByType(dossierID string, typ string) ([]*Entry, error) { var result []*Entry - return result, Query("SELECT * FROM entries WHERE dossier_id = ? AND type = ? AND (parent_id IS NULL OR parent_id = '' OR parent_id = '0') ORDER BY timestamp DESC", + return result, dbQuery("SELECT * FROM entries WHERE dossier_id = ? AND type = ? AND (parent_id IS NULL OR parent_id = '' OR parent_id = '0') ORDER BY timestamp DESC", []any{dossierID, CryptoEncrypt(typ)}, &result) } // EntryRootByCategory finds the root entry (parent_id IS NULL) for a category func EntryRootByCategory(dossierID string, category int) (*Entry, error) { var result []*Entry - err := Query("SELECT * FROM entries WHERE dossier_id = ? AND category = ? AND (parent_id IS NULL OR parent_id = '') LIMIT 1", + err := dbQuery("SELECT * FROM entries WHERE dossier_id = ? AND category = ? AND (parent_id IS NULL OR parent_id = '') LIMIT 1", []any{dossierID, category}, &result) if err != nil { return nil, err @@ -201,7 +204,7 @@ func EntryRootByCategory(dossierID string, category int) (*Entry, error) { // EntryTypes returns distinct types for a dossier+category func EntryTypes(dossierID string, category int) ([]string, error) { var entries []*Entry - if err := Query("SELECT DISTINCT type FROM entries WHERE dossier_id = ? AND category = ?", + if err := dbQuery("SELECT DISTINCT type FROM entries WHERE dossier_id = ? AND category = ?", []any{dossierID, category}, &entries); err != nil { return nil, err } @@ -304,16 +307,19 @@ func boolToInt(b bool) int { return 0 } -// EntryDeleteByCategory removes all entries with a given category for a dossier -func EntryDeleteByCategory(dossierID string, category int) error { +// EntryDeleteByCategory removes all entries with a given category for a dossier. Requires delete permission. +func EntryDeleteByCategory(ctx *AccessContext, dossierID string, category int) error { + if err := checkAccess(accessorIDFromContext(ctx), dossierID, "", category, 'd'); err != nil { + return err + } // Query all entries with this category, then delete each var entries []*Entry - if err := Query("SELECT entry_id FROM entries WHERE dossier_id = ? AND category = ?", + if err := dbQuery("SELECT entry_id FROM entries WHERE dossier_id = ? AND category = ?", []any{dossierID, category}, &entries); err != nil { return err } for _, e := range entries { - if err := Delete("entries", "entry_id", e.EntryID); err != nil { + if err := dbDelete("entries", "entry_id", e.EntryID); err != nil { return err } } @@ -330,6 +336,11 @@ func EntryAddBatch(entries []*Entry) error { return EntryWrite(nil, entries...) // nil ctx = internal operation } +// EntryAddBatchValues inserts multiple entries from a value slice (internal operation) +func EntryAddBatchValues(entries []Entry) error { + return dbSave("entries", entries) +} + // DossierSetSessionToken sets the mobile session token (internal/auth operation) func DossierSetSessionToken(dossierID string, token string) error { d, err := DossierGet(nil, dossierID) // nil ctx = internal operation diff --git a/lib/db_queries.go b/lib/db_queries.go index f465168..36e5449 100644 --- a/lib/db_queries.go +++ b/lib/db_queries.go @@ -4,7 +4,8 @@ package lib // ⛔ CRITICAL: DO NOT MODIFY THIS FILE WITHOUT JOHAN'S EXPRESS CONSENT // ============================================================================ // This is the ONLY file allowed to access the database directly. -// All other code must use these functions: Save, Load, Query, Delete, Count +// Internal DB functions (unexported): dbSave, dbLoad, dbQuery, dbDelete, dbCount +// External code must use RBAC-checked functions (EntryWrite, DossierGet, etc.) // // Run `make check-db` to verify no direct DB access exists elsewhere. // ============================================================================ @@ -230,11 +231,11 @@ func VerifyAll(pairs ...any) error { return nil } -// Save upserts struct(s) to the database. +// dbSave upserts struct(s) to the database. // Accepts a single struct or a slice of structs. // String and []byte fields are encrypted automatically. // Slices are wrapped in a transaction for atomicity. -func Save(table string, v any) error { +func dbSave(table string, v any) error { start := time.Now() defer func() { logSlowQuery("INSERT OR REPLACE INTO "+table, time.Since(start)) }() @@ -333,9 +334,9 @@ func Save(table string, v any) error { return err } -// Load retrieves a record by primary key and populates the struct. +// dbLoad retrieves a record by primary key and populates the struct. // String and []byte fields are decrypted automatically. -func Load(table string, id string, v any) error { +func dbLoad(table string, id string, v any) error { start := time.Now() defer func() { logSlowQuery("SELECT FROM "+table+" WHERE pk=?", time.Since(start), id) }() @@ -383,10 +384,10 @@ func Load(table string, id string, v any) error { return nil } -// Query runs a SQL query and populates the slice. +// dbQuery runs a SQL query and populates the slice. // Column names in result must match struct db tags. // String and []byte fields are decrypted automatically. -func Query(query string, args []any, slicePtr any) error { +func dbQuery(query string, args []any, slicePtr any) error { start := time.Now() defer func() { logSlowQuery(query, time.Since(start), args...) }() @@ -467,26 +468,25 @@ func Query(query string, args []any, slicePtr any) error { return nil } -// Count runs a SELECT COUNT(*) query and returns the result. -// Example: Count("SELECT COUNT(*) FROM entries WHERE dossier_id = ? AND category = ?", dossierID, category) -func Count(query string, args ...any) (int, error) { +// dbCount runs a SELECT COUNT(*) query and returns the result. +func dbCount(query string, args ...any) (int, error) { var count int err := db.QueryRow(query, args...).Scan(&count) return count, err } -// Delete removes a record by primary key. +// dbDelete removes a record by primary key. // pkCol is the primary key column name, id is 16-char hex string. -func Delete(table, pkCol, id string) error { +func dbDelete(table, pkCol, id string) error { query := fmt.Sprintf("DELETE FROM %s WHERE %s = ?", table, pkCol) _, err := db.Exec(query, id) return err } -// DeleteTree removes a record and all its descendants. +// dbDeleteTree removes a record and all its descendants. // Traverses the parent-child hierarchy recursively, deletes children first. // Works with any SQL database (no CTEs or CASCADE needed). -func DeleteTree(table, pkCol, parentCol, id string) error { +func dbDeleteTree(table, pkCol, parentCol, id string) error { // Collect all IDs (parent + descendants) var ids []string var collect func(string) error diff --git a/lib/journal.go b/lib/journal.go index c18b745..7af59fa 100644 --- a/lib/journal.go +++ b/lib/journal.go @@ -177,7 +177,7 @@ func CreateJournal(input CreateJournalInput) (string, error) { Status: input.Status, // defaults to 0 (draft) if not set } - if err := Save("entries", &entry); err != nil { + if err := dbSave("entries", &entry); err != nil { return "", fmt.Errorf("failed to save entry: %w", err) } @@ -187,7 +187,7 @@ func CreateJournal(input CreateJournalInput) (string, error) { // GetJournal retrieves a full journal entry func GetJournal(dossierID, entryID string) (*JournalEntry, error) { var entry Entry - if err := Load("entries", entryID, &entry); err != nil { + if err := dbLoad("entries", entryID, &entry); err != nil { return nil, fmt.Errorf("failed to load entry: %w", err) } @@ -271,7 +271,7 @@ func ListJournals(input ListJournalsInput) ([]JournalSummary, error) { // Execute query var entries []Entry - if err := Query(query, args, &entries); err != nil { + if err := dbQuery(query, args, &entries); err != nil { return nil, fmt.Errorf("failed to query entries: %w", err) } @@ -310,7 +310,7 @@ type UpdateJournalStatusInput struct { func UpdateJournalStatus(input UpdateJournalStatusInput) error { // Load entry var entry Entry - if err := Load("entries", input.EntryID, &entry); err != nil { + if err := dbLoad("entries", input.EntryID, &entry); err != nil { return fmt.Errorf("failed to load entry: %w", err) } @@ -351,7 +351,7 @@ func UpdateJournalStatus(input UpdateJournalStatusInput) error { } // Save entry - if err := Save("entries", &entry); err != nil { + if err := dbSave("entries", &entry); err != nil { return fmt.Errorf("failed to save entry: %w", err) } diff --git a/lib/lab_reference.go b/lib/lab_reference.go index 48876a0..4f23962 100644 --- a/lib/lab_reference.go +++ b/lib/lab_reference.go @@ -60,7 +60,7 @@ func MakeRefID(loinc, source, sex string, ageDays int64) string { // LabTestGet retrieves a LabTest by LOINC code. Returns nil if not found. func LabTestGet(loincID string) (*LabTest, error) { var t LabTest - if err := Load("lab_test", loincID, &t); err != nil { + if err := dbLoad("lab_test", loincID, &t); err != nil { return nil, err } return &t, nil @@ -68,7 +68,7 @@ func LabTestGet(loincID string) (*LabTest, error) { // LabTestSave upserts a LabTest record. func LabTestSave(t *LabTest) error { - return Save("lab_test", t) + return dbSave("lab_test", t) } // LabTestSaveBatch upserts multiple LabTest records. @@ -76,13 +76,13 @@ func LabTestSaveBatch(tests []LabTest) error { if len(tests) == 0 { return nil } - return Save("lab_test", tests) + return dbSave("lab_test", tests) } // LabRefSave upserts a LabReference record (auto-generates ref_id). func LabRefSave(r *LabReference) error { r.RefID = MakeRefID(r.LoincID, r.Source, r.Sex, r.AgeDays) - return Save("lab_reference", r) + return dbSave("lab_reference", r) } // LabRefSaveBatch upserts multiple LabReference records (auto-generates ref_ids). @@ -93,14 +93,21 @@ func LabRefSaveBatch(refs []LabReference) error { for i := range refs { refs[i].RefID = MakeRefID(refs[i].LoincID, refs[i].Source, refs[i].Sex, refs[i].AgeDays) } - return Save("lab_reference", refs) + return dbSave("lab_reference", refs) +} + +// LabRefLookupAll returns all reference ranges for a LOINC code. +func LabRefLookupAll(loincID string) ([]LabReference, error) { + var refs []LabReference + return refs, dbQuery("SELECT ref_id, loinc_id, source, sex, age_days, age_end, ref_low, ref_high, unit FROM lab_reference WHERE loinc_id = ?", + []any{loincID}, &refs) } // LabRefLookup finds the matching reference range for a test at a given age/sex. // Returns nil if no matching reference found. func LabRefLookup(loincID, sex string, ageDays int64) (*LabReference, error) { var refs []LabReference - if err := Query( + if err := dbQuery( "SELECT ref_id, loinc_id, source, sex, age_days, age_end, ref_low, ref_high, unit FROM lab_reference WHERE loinc_id = ?", []any{loincID}, &refs, ); err != nil { @@ -154,13 +161,13 @@ func PopulateReferences() error { // Load all lab_test entries var tests []LabTest - if err := Query("SELECT loinc_id, name, si_unit, direction, si_factor FROM lab_test", nil, &tests); err != nil { + if err := dbQuery("SELECT loinc_id, name, si_unit, direction, si_factor FROM lab_test", nil, &tests); err != nil { return fmt.Errorf("load lab_test: %w", err) } // Find which ones already have references var existingRefs []LabReference - if err := Query("SELECT ref_id, loinc_id FROM lab_reference", nil, &existingRefs); err != nil { + if err := dbQuery("SELECT ref_id, loinc_id FROM lab_reference", nil, &existingRefs); err != nil { return fmt.Errorf("load lab_reference: %w", err) } hasRef := make(map[string]bool) diff --git a/lib/normalize.go b/lib/normalize.go index cb38ed8..4c411d3 100644 --- a/lib/normalize.go +++ b/lib/normalize.go @@ -23,7 +23,7 @@ func Normalize(dossierID string, category int) error { Type string `db:"type"` } var rows []typeRow - if err := Query("SELECT type FROM entries WHERE dossier_id = ? AND category = ? GROUP BY type", + if err := dbQuery("SELECT type FROM entries WHERE dossier_id = ? AND category = ? GROUP BY type", []any{dossierID, category}, &rows); err != nil { return fmt.Errorf("query unique types: %w", err) } @@ -185,7 +185,7 @@ func Normalize(dossierID string, category int) error { } log.Printf("normalize: updating %d entries", len(toSave)) - return Save("entries", toSave) + return dbSave("entries", toSave) } // normalizeKey reduces a test name to a heuristic grouping key. diff --git a/lib/tracker.go b/lib/tracker.go index a85185b..ecf6896 100644 --- a/lib/tracker.go +++ b/lib/tracker.go @@ -20,31 +20,31 @@ func TrackerAdd(p *Tracker) error { if p.Active == false && p.Dismissed == false { p.Active = true // default to active } - return Save("trackers", p) + return dbSave("trackers", p) } // TrackerModify updates an existing prompt func TrackerModify(p *Tracker) error { p.UpdatedAt = time.Now().Unix() - return Save("trackers", p) + return dbSave("trackers", p) } // TrackerDelete removes a prompt func TrackerDelete(trackerID string) error { - return Delete("trackers", "tracker_id", trackerID) + return dbDelete("trackers", "tracker_id", trackerID) } // TrackerGet retrieves a single tracker by ID func TrackerGet(trackerID string) (*Tracker, error) { p := &Tracker{} - return p, Load("trackers", trackerID, p) + return p, dbLoad("trackers", trackerID, p) } // TrackerQueryActive retrieves active trackers due for a dossier func TrackerQueryActive(dossierID string) ([]*Tracker, error) { now := time.Now().Unix() var result []*Tracker - err := Query(`SELECT * FROM trackers + err := dbQuery(`SELECT * FROM trackers WHERE dossier_id = ? AND active = 1 AND dismissed = 0 AND (expires_at = 0 OR expires_at > ?) ORDER BY @@ -56,7 +56,7 @@ func TrackerQueryActive(dossierID string) ([]*Tracker, error) { // TrackerQueryAll retrieves all trackers for a dossier (including inactive) func TrackerQueryAll(dossierID string) ([]*Tracker, error) { var result []*Tracker - err := Query(`SELECT * FROM trackers WHERE dossier_id = ? ORDER BY active DESC, time_of_day, created_at`, + err := dbQuery(`SELECT * FROM trackers WHERE dossier_id = ? ORDER BY active DESC, time_of_day, created_at`, []any{dossierID}, &result) return result, err } @@ -77,7 +77,7 @@ func TrackerRespond(trackerID string, response, responseRaw string) error { p.NextAsk = calculateNextAsk(p.Frequency, p.TimeOfDay, now) p.UpdatedAt = now - if err := Save("trackers", p); err != nil { + if err := dbSave("trackers", p); err != nil { return err } @@ -149,7 +149,7 @@ func TrackerDismiss(trackerID string) error { } p.Dismissed = true p.UpdatedAt = time.Now().Unix() - return Save("trackers", p) + return dbSave("trackers", p) } // TrackerSkip advances next_ask to tomorrow without recording a response @@ -161,7 +161,7 @@ func TrackerSkip(trackerID string) error { now := time.Now().Unix() p.NextAsk = now + 24*60*60 p.UpdatedAt = now - return Save("trackers", p) + return dbSave("trackers", p) } // calculateNextAsk determines when to ask again based on frequency diff --git a/lib/v2.go b/lib/v2.go index 0132233..b3e6585 100644 --- a/lib/v2.go +++ b/lib/v2.go @@ -73,9 +73,9 @@ func EntryWrite(ctx *AccessContext, entries ...*Entry) error { } } if len(entries) == 1 { - return Save("entries", entries[0]) + return dbSave("entries", entries[0]) } - return Save("entries", entries) + return dbSave("entries", entries) } // EntryRemove deletes entries. Requires delete permission. @@ -101,11 +101,11 @@ func EntryRemoveByDossier(ctx *AccessContext, dossierID string) error { } var entries []*Entry - if err := Query("SELECT entry_id FROM entries WHERE dossier_id = ?", []any{dossierID}, &entries); err != nil { + if err := dbQuery("SELECT entry_id FROM entries WHERE dossier_id = ?", []any{dossierID}, &entries); err != nil { return err } for _, e := range entries { - if err := Delete("entries", "entry_id", e.EntryID); err != nil { + if err := dbDelete("entries", "entry_id", e.EntryID); err != nil { return err } } @@ -130,7 +130,7 @@ func EntryGet(ctx *AccessContext, id string) (*Entry, error) { // entryGetRaw retrieves an entry without permission check (internal use only) func entryGetRaw(id string) (*Entry, error) { e := &Entry{} - return e, Load("entries", id, e) + return e, dbLoad("entries", id, e) } // EntryList retrieves entries. Requires read permission on parent/dossier. @@ -203,7 +203,7 @@ func EntryList(accessorID string, parent string, category int, f *EntryFilter) ( } var result []*Entry - err := Query(q, args, &result) + err := dbQuery(q, args, &result) return result, err } @@ -242,9 +242,9 @@ func DossierWrite(ctx *AccessContext, dossiers ...*Dossier) error { } } if len(dossiers) == 1 { - return Save("dossiers", dossiers[0]) + return dbSave("dossiers", dossiers[0]) } - return Save("dossiers", dossiers) + return dbSave("dossiers", dossiers) } // DossierRemove deletes dossiers. Requires manage permission. @@ -271,7 +271,7 @@ func DossierGet(ctx *AccessContext, id string) (*Dossier, error) { // dossierGetRaw retrieves a dossier without permission check (internal use only) func dossierGetRaw(id string) (*Dossier, error) { d := &Dossier{} - if err := Load("dossiers", id, d); err != nil { + if err := dbLoad("dossiers", id, d); err != nil { return nil, err } // Parse DOB from encrypted string @@ -306,7 +306,7 @@ func DossierList(ctx *AccessContext, f *DossierFilter) ([]*Dossier, error) { } var result []*Dossier - err := Query(q, args, &result) + err := dbQuery(q, args, &result) return result, err } @@ -323,7 +323,7 @@ func DossierGetByEmail(ctx *AccessContext, email string) (*Dossier, error) { } q := "SELECT * FROM dossiers WHERE email = ? LIMIT 1" var result []*Dossier - if err := Query(q, []any{CryptoEncrypt(email)}, &result); err != nil { + if err := dbQuery(q, []any{CryptoEncrypt(email)}, &result); err != nil { return nil, err } if len(result) == 0 { @@ -339,7 +339,7 @@ func DossierGetBySessionToken(token string) *Dossier { } q := "SELECT * FROM dossiers WHERE session_token = ? LIMIT 1" var result []*Dossier - if err := Query(q, []any{CryptoEncrypt(token)}, &result); err != nil { + if err := dbQuery(q, []any{CryptoEncrypt(token)}, &result); err != nil { return nil } if len(result) == 0 { @@ -367,9 +367,9 @@ func AccessWrite(records ...*DossierAccess) error { } } if len(records) == 1 { - return Save("dossier_access", records[0]) + return dbSave("dossier_access", records[0]) } - return Save("dossier_access", records) + return dbSave("dossier_access", records) } func AccessRemove(accessorID, targetID string) error { @@ -377,13 +377,13 @@ func AccessRemove(accessorID, targetID string) error { if err != nil { return err } - return Delete("dossier_access", "access_id", access.AccessID) + return dbDelete("dossier_access", "access_id", access.AccessID) } func AccessGet(accessorID, targetID string) (*DossierAccess, error) { q := "SELECT * FROM dossier_access WHERE accessor_dossier_id = ? AND target_dossier_id = ?" var result []*DossierAccess - if err := Query(q, []any{accessorID, targetID}, &result); err != nil { + if err := dbQuery(q, []any{accessorID, targetID}, &result); err != nil { return nil, err } if len(result) == 0 { @@ -412,7 +412,7 @@ func AccessList(f *AccessFilter) ([]*DossierAccess, error) { } var result []*DossierAccess - err := Query(q, args, &result) + err := dbQuery(q, args, &result) return result, err } @@ -437,9 +437,9 @@ func AuditWrite(entries ...*AuditEntry) error { } } if len(entries) == 1 { - return Save("audit", entries[0]) + return dbSave("audit", entries[0]) } - return Save("audit", entries) + return dbSave("audit", entries) } func AuditList(f *AuditFilter) ([]*AuditEntry, error) { @@ -476,7 +476,7 @@ func AuditList(f *AuditFilter) ([]*AuditEntry, error) { } var result []*AuditEntry - err := Query(q, args, &result) + err := dbQuery(q, args, &result) return result, err } @@ -500,9 +500,9 @@ func TrackerWrite(trackers ...*Tracker) error { } } if len(trackers) == 1 { - return Save("trackers", trackers[0]) + return dbSave("trackers", trackers[0]) } - return Save("trackers", trackers) + return dbSave("trackers", trackers) } func TrackerRemove(ids ...string) error { @@ -538,7 +538,7 @@ func TrackerList(f *TrackerFilter) ([]*Tracker, error) { } var result []*Tracker - err := Query(q, args, &result) + err := dbQuery(q, args, &result) return result, err } @@ -751,7 +751,7 @@ func AccessListByTargetWithNames(targetID string) ([]map[string]interface{}, err // TrackerDistinctTypes returns distinct category/type pairs for a dossier's active trackers func TrackerDistinctTypes(dossierID string) (map[string][]string, error) { var trackers []*Tracker - if err := Query("SELECT * FROM trackers WHERE dossier_id = ? AND active = 1", []any{dossierID}, &trackers); err != nil { + if err := dbQuery("SELECT * FROM trackers WHERE dossier_id = ? AND active = 1", []any{dossierID}, &trackers); err != nil { return nil, err } @@ -791,15 +791,15 @@ func AccessGrantWrite(grants ...*Access) error { } } if len(grants) == 1 { - return Save("access", grants[0]) + return dbSave("access", grants[0]) } - return Save("access", grants) + return dbSave("access", grants) } // AccessGrantRemove removes access grants by ID func AccessGrantRemove(ids ...string) error { for _, id := range ids { - if err := Delete("access", "access_id", id); err != nil { + if err := dbDelete("access", "access_id", id); err != nil { return err } } @@ -814,7 +814,7 @@ func MigrateOldAccess() int { CanEdit int `db:"can_edit"` } var entries []oldAccess - err := Query("SELECT accessor_dossier_id, target_dossier_id, can_edit FROM dossier_access WHERE status = 1", nil, &entries) + err := dbQuery("SELECT accessor_dossier_id, target_dossier_id, can_edit FROM dossier_access WHERE status = 1", nil, &entries) if err != nil { return 0 } @@ -852,7 +852,7 @@ func MigrateOldAccess() int { func MigrateStudiesToCategoryRoot() int { // Find all imaging entries with empty parent_id, filter to studies in Go var all []*Entry - err := Query( + err := dbQuery( "SELECT * FROM entries WHERE category = ? AND (parent_id IS NULL OR parent_id = '')", []any{CategoryImaging}, &all) if err != nil { @@ -882,7 +882,7 @@ func MigrateStudiesToCategoryRoot() int { } s.ParentID = rootID - if err := Save("entries", s); err == nil { + if err := dbSave("entries", s); err == nil { migrated++ } } @@ -892,7 +892,7 @@ func MigrateStudiesToCategoryRoot() int { // AccessGrantGet retrieves a single access grant by ID func AccessGrantGet(id string) (*Access, error) { a := &Access{} - return a, Load("access", id, a) + return a, dbLoad("access", id, a) } // AccessGrantList retrieves access grants with optional filtering @@ -922,7 +922,7 @@ func AccessGrantList(f *PermissionFilter) ([]*Access, error) { q += " ORDER BY created_at DESC" var result []*Access - err := Query(q, args, &result) + err := dbQuery(q, args, &result) return result, err } @@ -956,7 +956,7 @@ func AccessRoleTemplates(dossierID string) ([]*Access, error) { q += " ORDER BY role, entry_id" var result []*Access - err := Query(q, args, &result) + err := dbQuery(q, args, &result) return result, err } @@ -1003,7 +1003,7 @@ func AccessRevokeAll(dossierID, granteeID string) error { return err } for _, g := range grants { - if err := Delete("access", "access_id", g.AccessID); err != nil { + if err := dbDelete("access", "access_id", g.AccessID); err != nil { return err } } @@ -1017,7 +1017,7 @@ func AccessRevokeEntry(dossierID, granteeID, entryID string) error { return err } for _, g := range grants { - if err := Delete("access", "access_id", g.AccessID); err != nil { + if err := dbDelete("access", "access_id", g.AccessID); err != nil { return err } } @@ -1209,10 +1209,14 @@ type GenomeQueryOpts struct { AccessorID string // who is querying (for audit logging) } -// GenomeQuery queries genome variants for a dossier. +// GenomeQuery queries genome variants for a dossier. Requires read permission on genome data. // Fast path: gene/rsid use indexed search_key/type columns (precise SQL queries). // Slow path: search/min_magnitude load all variants and filter in memory. -func GenomeQuery(dossierID string, opts GenomeQueryOpts) (*GenomeQueryResult, error) { +func GenomeQuery(ctx *AccessContext, dossierID string, opts GenomeQueryOpts) (*GenomeQueryResult, error) { + if err := checkAccess(accessorIDFromContext(ctx), dossierID, "", CategoryGenome, 'r'); err != nil { + return nil, err + } + if opts.IncludeHidden { var details []string if opts.Gene != "" { @@ -1277,7 +1281,7 @@ func genomeQueryFast(dossierID string, opts GenomeQueryOpts, limit int) (*Genome sql += " AND type IN (" + strings.Join(rsidPlaceholders, ",") + ")" } - Query(sql, args, &entries) + dbQuery(sql, args, &entries) } else if len(opts.RSIDs) > 0 { // rsid only, no gene — single IN query placeholders := make([]string, len(opts.RSIDs)) @@ -1287,7 +1291,7 @@ func genomeQueryFast(dossierID string, opts GenomeQueryOpts, limit int) (*Genome args = append(args, CryptoEncrypt(rsid)) } sql := "SELECT * FROM entries WHERE dossier_id = ? AND category = ? AND type IN (" + strings.Join(placeholders, ",") + ")" - Query(sql, args, &entries) + dbQuery(sql, args, &entries) } // Look up tier categories for parent_ids (single IN query) @@ -1304,7 +1308,7 @@ func genomeQueryFast(dossierID string, opts GenomeQueryOpts, limit int) (*Genome args = append(args, id) } var tierEntries []Entry - Query("SELECT * FROM entries WHERE entry_id IN ("+strings.Join(placeholders, ",")+")", args, &tierEntries) + dbQuery("SELECT * FROM entries WHERE entry_id IN ("+strings.Join(placeholders, ",")+")", args, &tierEntries) for _, t := range tierEntries { tierCategories[t.EntryID] = t.Value } @@ -1499,11 +1503,81 @@ func genomeEntriesToResult(entries []Entry, tierCategories map[string]string, op }, nil } +// --- RBAC-CHECKED QUERY HELPERS --- + +// EntryCategoryCounts returns entry counts by category for a dossier. +func EntryCategoryCounts(ctx *AccessContext, dossierID string) (map[string]int, error) { + if err := checkAccess(accessorIDFromContext(ctx), dossierID, "", 0, 'r'); err != nil { + return nil, err + } + var counts []struct { + Category int `db:"category"` + Count int `db:"cnt"` + } + if err := dbQuery("SELECT category, COUNT(*) as cnt FROM entries WHERE dossier_id = ? AND category > 0 GROUP BY category", []any{dossierID}, &counts); err != nil { + return nil, err + } + result := make(map[string]int) + for _, c := range counts { + name := CategoryName(c.Category) + if name != "unknown" { + result[name] = c.Count + } + } + return result, nil +} + +// EntryCount returns entry count for a dossier by category and optional type. +func EntryCount(ctx *AccessContext, dossierID string, category int, typ string) (int, error) { + if err := checkAccess(accessorIDFromContext(ctx), dossierID, "", category, 'r'); err != nil { + return 0, err + } + if typ != "" { + return dbCount("SELECT COUNT(*) FROM entries WHERE dossier_id = ? AND category = ? AND type = ?", + dossierID, category, CryptoEncrypt(typ)) + } + return dbCount("SELECT COUNT(*) FROM entries WHERE dossier_id = ? AND category = ?", + dossierID, category) +} + +// EntryListByDossier returns all entries for a dossier ordered by category and timestamp. +func EntryListByDossier(ctx *AccessContext, dossierID string) ([]*Entry, error) { + if err := checkAccess(accessorIDFromContext(ctx), dossierID, "", 0, 'r'); err != nil { + return nil, err + } + var entries []*Entry + return entries, dbQuery("SELECT * FROM entries WHERE dossier_id = ? ORDER BY category, timestamp", []any{dossierID}, &entries) +} + +// LabTestList returns all lab tests (reference data, no RBAC needed). +func LabTestList() ([]LabTest, error) { + var tests []LabTest + return tests, dbQuery("SELECT loinc_id, name FROM lab_test", nil, &tests) +} + +// LabEntryListForIndex returns lab entries with data for building search indexes. +func LabEntryListForIndex() ([]*Entry, error) { + var entries []*Entry + return entries, dbQuery("SELECT entry_id, data FROM entries WHERE category = ? AND parent_id != ''", + []any{CategoryLab}, &entries) +} + +// LabRefListBySource returns lab references matching a source pattern in ref_id. +func LabRefListBySource(source string) ([]LabReference, error) { + var refs []LabReference + return refs, dbQuery("SELECT * FROM lab_reference WHERE ref_id LIKE ?", []any{"%|" + source + "|%"}, &refs) +} + +// LabRefDeleteByID deletes a lab reference by ref_id. +func LabRefDeleteByID(refID string) error { + return dbDelete("lab_reference", "ref_id", refID) +} + // --- HELPERS --- func deleteByIDs(table, col string, ids []string) error { for _, id := range ids { - if err := Delete(table, col, id); err != nil { + if err := dbDelete(table, col, id); err != nil { return err } } diff --git a/portal/dossier_sections.go b/portal/dossier_sections.go index f8a3c49..0fd9b8e 100644 --- a/portal/dossier_sections.go +++ b/portal/dossier_sections.go @@ -492,8 +492,8 @@ func entriesToSectionItems(entries []*lib.Entry) []SectionItem { // buildLoincNameMap builds a JSON map of LOINC code → full test name // for displaying full names in charts. func buildLoincNameMap() string { - var tests []lib.LabTest - if err := lib.Query("SELECT loinc_id, name FROM lab_test", nil, &tests); err != nil { + tests, err := lib.LabTestList() + if err != nil { return "{}" } @@ -510,8 +510,8 @@ func buildLoincNameMap() string { // buildLabSearchIndex builds a JSON map of search terms → LOINC codes // for client-side lab result filtering. Keys are lowercase test names and abbreviations. func buildLabSearchIndex() string { - var tests []lib.LabTest - if err := lib.Query("SELECT loinc_id, name FROM lab_test", nil, &tests); err != nil { + tests, err := lib.LabTestList() + if err != nil { return "{}" } @@ -542,9 +542,8 @@ func buildLabSearchIndex() string { // Also index by abbreviations from actual lab entries // Get unique LOINC+abbreviation pairs from all lab entries with parent - var entries []*lib.Entry - if err := lib.Query("SELECT entry_id, data FROM entries WHERE category = ? AND parent_id != ''", - []any{lib.CategoryLab}, &entries); err == nil { + entries, err2 := lib.LabEntryListForIndex() + if err2 == nil { seen := make(map[string]bool) // Track LOINC+abbr pairs to avoid duplicates for _, e := range entries { var data struct { diff --git a/portal/genome.go b/portal/genome.go index 898397c..09c51fe 100644 --- a/portal/genome.go +++ b/portal/genome.go @@ -218,7 +218,7 @@ func processGenomeUpload(uploadID string, dossierID string, filePath string) { } // Delete existing genome entries (all genome data uses CategoryGenome with different Types) - lib.EntryDeleteByCategory(dossierID, lib.CategoryGenome) + lib.EntryDeleteByCategory(nil, dossierID, lib.CategoryGenome) // nil ctx = internal genome processing // Create extraction entry (tier 1) now := time.Now().Unix() diff --git a/portal/main.go b/portal/main.go index bf6857c..12ba955 100644 --- a/portal/main.go +++ b/portal/main.go @@ -850,28 +850,15 @@ func handleInvite(w http.ResponseWriter, r *http.Request) { } func getDossierStats(dossierID string) DossierStats { + ctx := &lib.AccessContext{AccessorID: dossierID} // Self-access for dashboard var stats DossierStats - // Count studies (not slices/series) - stats.Imaging, _ = lib.Count(`SELECT COUNT(*) FROM entries WHERE dossier_id = ? AND category = ? AND type = ?`, - dossierID, lib.CategoryImaging, lib.CryptoEncrypt("study")) - // Count lab reports - stats.Labs, _ = lib.Count(`SELECT COUNT(*) FROM entries WHERE dossier_id = ? AND category = ? AND type = ?`, - dossierID, lib.CategoryLab, lib.CryptoEncrypt("lab_report")) - // Check if genome data exists (count tiers) - stats.Genome, _ = lib.Count(`SELECT COUNT(*) FROM entries WHERE dossier_id = ? AND category = ? AND type = ?`, - dossierID, lib.CategoryGenome, lib.CryptoEncrypt("tier")) - // Documents - stats.Documents, _ = lib.Count(`SELECT COUNT(*) FROM entries WHERE dossier_id = ? AND category = ?`, - dossierID, lib.CategoryDocument) - // Vitals - stats.Vitals, _ = lib.Count(`SELECT COUNT(*) FROM entries WHERE dossier_id = ? AND category = ?`, - dossierID, lib.CategoryVital) - // Medications - stats.Medications, _ = lib.Count(`SELECT COUNT(*) FROM entries WHERE dossier_id = ? AND category = ?`, - dossierID, lib.CategoryMedication) - // Supplements - stats.Supplements, _ = lib.Count(`SELECT COUNT(*) FROM entries WHERE dossier_id = ? AND category = ?`, - dossierID, lib.CategorySupplement) + stats.Imaging, _ = lib.EntryCount(ctx, dossierID, lib.CategoryImaging, "study") + stats.Labs, _ = lib.EntryCount(ctx, dossierID, lib.CategoryLab, "lab_report") + stats.Genome, _ = lib.EntryCount(ctx, dossierID, lib.CategoryGenome, "tier") + stats.Documents, _ = lib.EntryCount(ctx, dossierID, lib.CategoryDocument, "") + stats.Vitals, _ = lib.EntryCount(ctx, dossierID, lib.CategoryVital, "") + stats.Medications, _ = lib.EntryCount(ctx, dossierID, lib.CategoryMedication, "") + stats.Supplements, _ = lib.EntryCount(ctx, dossierID, lib.CategorySupplement, "") return stats } @@ -1189,8 +1176,7 @@ func handleExportData(w http.ResponseWriter, r *http.Request) { if err != nil || dossier == nil { http.NotFound(w, r); return } // Get ALL entries for this dossier (including nested) - var entries []*lib.Entry - lib.Query("SELECT * FROM entries WHERE dossier_id = ? ORDER BY category, timestamp", []any{targetID}, &entries) + entries, _ := lib.EntryListByDossier(nil, targetID) // nil ctx = internal export operation // Build clean export structure (no IDs) type ExportDossier struct { diff --git a/portal/mcp_http.go b/portal/mcp_http.go index 425a39f..58c403a 100644 --- a/portal/mcp_http.go +++ b/portal/mcp_http.go @@ -614,7 +614,7 @@ func handleMCPToolsCall(w http.ResponseWriter, req mcpRequest, accessToken, doss sendMCPError(w, req.ID, -32602, "dossier required") return } - result, err := mcpGetCategories(dossier) + result, err := mcpGetCategories(dossier, dossierID) if err != nil { sendMCPError(w, req.ID, -32000, err.Error()) return diff --git a/portal/mcp_tools.go b/portal/mcp_tools.go index 285a49b..6fa840e 100644 --- a/portal/mcp_tools.go +++ b/portal/mcp_tools.go @@ -181,27 +181,18 @@ func mcpQueryEntries(accessToken, dossier, category, typ, searchKey, parent, fro return string(pretty), nil } -func mcpGetCategories(dossier string) (string, error) { - var counts []struct { - Category int `db:"category"` - Count int `db:"cnt"` - } - err := lib.Query("SELECT category, COUNT(*) as cnt FROM entries WHERE dossier_id = ? GROUP BY category", []any{dossier}, &counts) +func mcpGetCategories(dossier, accessorID string) (string, error) { + ctx := &lib.AccessContext{AccessorID: accessorID} + result, err := lib.EntryCategoryCounts(ctx, dossier) if err != nil { return "", err } - result := make(map[string]int) - for _, c := range counts { - name := lib.CategoryName(c.Category) - if name != "unknown" { - result[name] = c.Count - } - } pretty, _ := json.MarshalIndent(result, "", " ") return string(pretty), nil } func mcpQueryGenome(accessToken, dossier, accessorID, gene, search, category, rsids string, minMag float64, repute string, includeHidden bool, limit, offset int) (string, error) { + ctx := &lib.AccessContext{AccessorID: accessorID} var rsidList []string if rsids != "" { rsidList = strings.Split(rsids, ",") @@ -218,7 +209,7 @@ func mcpQueryGenome(accessToken, dossier, accessorID, gene, search, category, rs limit = 20 * numTerms } - result, err := lib.GenomeQuery(dossier, lib.GenomeQueryOpts{ + result, err := lib.GenomeQuery(ctx, dossier, lib.GenomeQueryOpts{ Category: category, Search: search, Gene: gene, diff --git a/portal/upload.go b/portal/upload.go index faa0fc1..ad345c8 100644 --- a/portal/upload.go +++ b/portal/upload.go @@ -211,7 +211,7 @@ func handleUploadPost(w http.ResponseWriter, r *http.Request) { // Delete existing upload with same filename (re-upload cleanup) existingUploads := findUploadByFilename(targetID, fileName) for _, old := range existingUploads { - lib.EntryDeleteTree(targetID, old.EntryID) + lib.EntryDeleteTree(nil, targetID, old.EntryID) // nil ctx = internal upload cleanup } now := time.Now().Unix() diff --git a/scripts/check-db-access.sh b/scripts/check-db-access.sh index 4b92c17..1bbe2fd 100755 --- a/scripts/check-db-access.sh +++ b/scripts/check-db-access.sh @@ -133,6 +133,38 @@ else echo -e "${GREEN}OK${NC}" fi +echo "" +echo "=== Unexported DB Function Check ===" +echo "" + +# Now that Query/Save/Load/Delete/Count are unexported (dbQuery, dbSave, etc.), +# no code outside lib/ should reference them. The Go compiler enforces this, +# but this check catches it earlier (before build). +ALL_DIRS="portal api viewer mcp-client import-genome cmd find_dossiers tools test-prompts doc-processor" +for fn in "lib\.Query(" "lib\.Save(" "lib\.Load(" "lib\.Delete(" "lib\.Count("; do + name=$(echo "$fn" | sed 's/lib\\.//;s/($//') + echo -n "Checking for $fn outside lib/... " + MATCHES="" + for dir in $ALL_DIRS; do + if [ -d "$dir" ]; then + FOUND=$(grep -rn "$fn" --include="*.go" "$dir" 2>/dev/null || true) + # Filter out comments and string literals (rough heuristic: lines with // before the match) + if [ -n "$FOUND" ]; then + REAL=$(echo "$FOUND" | grep -v "^\s*//" | grep -v "^\s*\*" || true) + MATCHES="$MATCHES$REAL" + fi + fi + done + if [ -n "$MATCHES" ]; then + echo -e "${RED}FAILED${NC}" + echo "$MATCHES" + echo " lib.$name() is unexported — use RBAC-checked functions instead" + ERRORS=$((ERRORS + 1)) + else + echo -e "${GREEN}OK${NC}" + fi +done + echo "" echo "=== RBAC Enforcement Check ===" echo "" @@ -191,7 +223,7 @@ else echo -e "${RED}$ERRORS check(s) failed!${NC}" echo "" echo "Direct database access is FORBIDDEN without Johan's express consent." - echo "All DB operations must go through lib/db_queries.go functions:" - echo " - Save(), Load(), Query(), Delete(), Count()" + echo "All DB operations must go through RBAC-checked lib functions." + echo "Raw DB functions (dbQuery, dbSave, etc.) are unexported — only callable from lib/." exit 1 fi diff --git a/test-prompts/main.go b/test-prompts/main.go index 85afebc..6b4e878 100644 --- a/test-prompts/main.go +++ b/test-prompts/main.go @@ -137,7 +137,7 @@ func main() { // --- Local Prompt Handling Functions --- func loadPrompt(name string) (string, error) { - path := filepath.Join(lib.PromptsDir(), name+".md") + path := filepath.Join(lib.TrackerPromptsDir(), name+".md") data, err := os.ReadFile(path) if err != nil { return "", err diff --git a/tools/import-caliper/main.go b/tools/import-caliper/main.go index e5e694a..63b91ef 100644 --- a/tools/import-caliper/main.go +++ b/tools/import-caliper/main.go @@ -70,14 +70,13 @@ func main() { // Delete existing CALIPER references if !dryRun { - // Query existing to count - var existing []lib.LabReference - if err := lib.Query("SELECT ref_id FROM lab_reference WHERE ref_id LIKE '%|CALIPER|%'", nil, &existing); err != nil { + existing, err := lib.LabRefListBySource("CALIPER") + if err != nil { log.Printf("Warning: could not count existing: %v", err) } else { log.Printf("Deleting %d existing CALIPER references", len(existing)) for _, r := range existing { - lib.Delete("lab_reference", "ref_id", r.RefID) + lib.LabRefDeleteByID(r.RefID) } } } @@ -206,11 +205,8 @@ func main() { aliasTotal := 0 for _, pair := range aliases { wrong, correct := pair[0], pair[1] - var srcRefs []lib.LabReference - if err := lib.Query( - "SELECT ref_id, loinc_id, source, sex, age_days, age_end, ref_low, ref_high, unit FROM lab_reference WHERE loinc_id = ?", - []any{correct}, &srcRefs, - ); err != nil || len(srcRefs) == 0 { + srcRefs, err := lib.LabRefLookupAll(correct) + if err != nil || len(srcRefs) == 0 { continue } var copies []lib.LabReference