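// Package rbac implements role-based access control for entries: granting
// and revoking per-user permissions, and checking them, including permissions
// inherited from an entry's deal room.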
package rbac
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
|
|
"dealroom/internal/model"
|
|
)
|
|
|
|
// Engine evaluates and manages role-based access control for entries.
// It holds no state of its own beyond a database handle; every method reads
// or writes the access, users and entries tables through it.
type Engine struct {
	// db is the handle all RBAC queries run against.
	db *sql.DB
}
|
|
|
|
// New creates a new RBAC engine
|
|
func New(db *sql.DB) *Engine {
|
|
return &Engine{db: db}
|
|
}
|
|
|
|
// GrantAccess grants permissions to a user for an entry
|
|
func (e *Engine) GrantAccess(entryID, userID string, permissions int, grantedBy string) error {
|
|
query := `
|
|
INSERT OR REPLACE INTO access (id, entry_id, user_id, permissions, granted_by, granted_at)
|
|
VALUES (?, ?, ?, ?, ?, ?)
|
|
`
|
|
|
|
id := fmt.Sprintf("%s:%s", entryID, userID)
|
|
_, err := e.db.Exec(query, id, entryID, userID, permissions, grantedBy, model.Now())
|
|
return err
|
|
}
|
|
|
|
// RevokeAccess removes all permissions for a user on an entry
|
|
func (e *Engine) RevokeAccess(entryID, userID string) error {
|
|
query := `DELETE FROM access WHERE entry_id = ? AND user_id = ?`
|
|
_, err := e.db.Exec(query, entryID, userID)
|
|
return err
|
|
}
|
|
|
|
// CheckAccess verifies if a user has specific permissions for an entry
|
|
func (e *Engine) CheckAccess(userID, entryID string, permission int) (bool, error) {
|
|
// Admins have all permissions
|
|
if isAdmin, err := e.isUserAdmin(userID); err != nil {
|
|
return false, err
|
|
} else if isAdmin {
|
|
return true, nil
|
|
}
|
|
|
|
// Check direct permissions
|
|
access, err := e.getUserAccess(userID, entryID)
|
|
if err != nil && err != sql.ErrNoRows {
|
|
return false, err
|
|
}
|
|
|
|
if access != nil && access.HasPermission(permission) {
|
|
return true, nil
|
|
}
|
|
|
|
// Check inherited permissions from deal room
|
|
if inherited, err := e.checkInheritedAccess(userID, entryID, permission); err != nil {
|
|
return false, err
|
|
} else if inherited {
|
|
return true, nil
|
|
}
|
|
|
|
return false, nil
|
|
}
|
|
|
|
// GetUserAccess returns the access record for a user on an entry
|
|
func (e *Engine) GetUserAccess(userID, entryID string) (*model.Access, error) {
|
|
return e.getUserAccess(userID, entryID)
|
|
}
|
|
|
|
// GetEntryAccess returns all access records for an entry
|
|
func (e *Engine) GetEntryAccess(entryID string) ([]*model.Access, error) {
|
|
query := `
|
|
SELECT id, entry_id, user_id, permissions, granted_by, granted_at
|
|
FROM access
|
|
WHERE entry_id = ?
|
|
ORDER BY granted_at DESC
|
|
`
|
|
|
|
rows, err := e.db.Query(query, entryID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var accessList []*model.Access
|
|
for rows.Next() {
|
|
access := &model.Access{}
|
|
var grantedAt int64
|
|
|
|
err := rows.Scan(&access.ID, &access.EntryID, &access.UserID,
|
|
&access.Permissions, &access.GrantedBy, &grantedAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
access.GrantedAt = model.TimeFromUnix(grantedAt)
|
|
accessList = append(accessList, access)
|
|
}
|
|
|
|
return accessList, rows.Err()
|
|
}
|
|
|
|
// GetUserEntries returns entries accessible to a user with specified permissions
|
|
func (e *Engine) GetUserEntries(userID string, entryType string, permission int) ([]*model.Entry, error) {
|
|
// Admins can see everything
|
|
if isAdmin, err := e.isUserAdmin(userID); err != nil {
|
|
return nil, err
|
|
} else if isAdmin {
|
|
return e.getAllEntriesByType(entryType)
|
|
}
|
|
|
|
// Get entries with direct access
|
|
query := `
|
|
SELECT DISTINCT e.id, e.parent_id, e.deal_room_id, e.entry_type, e.title,
|
|
e.content, e.file_path, e.file_size, e.file_hash,
|
|
e.created_by, e.created_at, e.updated_at
|
|
FROM entries e
|
|
JOIN access a ON e.id = a.entry_id
|
|
WHERE a.user_id = ? AND (a.permissions & ?) > 0
|
|
AND ($1 = '' OR e.entry_type = $1)
|
|
ORDER BY e.created_at DESC
|
|
`
|
|
|
|
rows, err := e.db.Query(query, userID, permission, entryType)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
return e.scanEntries(rows)
|
|
}
|
|
|
|
// Helper methods
|
|
|
|
func (e *Engine) isUserAdmin(userID string) (bool, error) {
|
|
query := `SELECT role FROM users WHERE id = ? AND is_active = 1`
|
|
var role string
|
|
err := e.db.QueryRow(query, userID).Scan(&role)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return false, nil
|
|
}
|
|
return false, err
|
|
}
|
|
return role == "admin", nil
|
|
}
|
|
|
|
func (e *Engine) getUserAccess(userID, entryID string) (*model.Access, error) {
|
|
query := `
|
|
SELECT id, entry_id, user_id, permissions, granted_by, granted_at
|
|
FROM access
|
|
WHERE user_id = ? AND entry_id = ?
|
|
`
|
|
|
|
access := &model.Access{}
|
|
var grantedAt int64
|
|
|
|
err := e.db.QueryRow(query, userID, entryID).Scan(
|
|
&access.ID, &access.EntryID, &access.UserID,
|
|
&access.Permissions, &access.GrantedBy, &grantedAt)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
access.GrantedAt = model.TimeFromUnix(grantedAt)
|
|
return access, nil
|
|
}
|
|
|
|
func (e *Engine) checkInheritedAccess(userID, entryID string, permission int) (bool, error) {
|
|
// Get the deal room ID for this entry
|
|
query := `SELECT deal_room_id FROM entries WHERE id = ?`
|
|
var dealRoomID string
|
|
err := e.db.QueryRow(query, entryID).Scan(&dealRoomID)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
// If this is already a deal room, no inheritance
|
|
if dealRoomID == entryID {
|
|
return false, nil
|
|
}
|
|
|
|
// Check access to the deal room
|
|
return e.CheckAccess(userID, dealRoomID, permission)
|
|
}
|
|
|
|
func (e *Engine) getAllEntriesByType(entryType string) ([]*model.Entry, error) {
|
|
query := `
|
|
SELECT id, parent_id, deal_room_id, entry_type, title,
|
|
content, file_path, file_size, file_hash,
|
|
created_by, created_at, updated_at
|
|
FROM entries
|
|
WHERE ($1 = '' OR entry_type = $1)
|
|
ORDER BY created_at DESC
|
|
`
|
|
|
|
rows, err := e.db.Query(query, entryType)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
return e.scanEntries(rows)
|
|
}
|
|
|
|
func (e *Engine) scanEntries(rows *sql.Rows) ([]*model.Entry, error) {
|
|
var entries []*model.Entry
|
|
|
|
for rows.Next() {
|
|
entry := &model.Entry{}
|
|
var createdAt, updatedAt int64
|
|
var parentID, filePath, fileSize, fileHash sql.NullString
|
|
var fileSizeInt sql.NullInt64
|
|
|
|
err := rows.Scan(&entry.ID, &parentID, &entry.DealRoomID, &entry.EntryType,
|
|
&entry.Title, &entry.Content, &filePath, &fileSizeInt, &fileHash,
|
|
&entry.CreatedBy, &createdAt, &updatedAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if parentID.Valid {
|
|
entry.ParentID = &parentID.String
|
|
}
|
|
if filePath.Valid {
|
|
entry.FilePath = &filePath.String
|
|
}
|
|
if fileSizeInt.Valid {
|
|
entry.FileSize = &fileSizeInt.Int64
|
|
}
|
|
if fileHash.Valid {
|
|
entry.FileHash = &fileHash.String
|
|
}
|
|
|
|
entry.CreatedAt = model.TimeFromUnix(createdAt)
|
|
entry.UpdatedAt = model.TimeFromUnix(updatedAt)
|
|
|
|
entries = append(entries, entry)
|
|
}
|
|
|
|
return entries, rows.Err()
|
|
} |