Add test suite: crypto, dbcore, rbac, auth middleware, integration
This commit is contained in:
parent
242e063855
commit
5ac277ce6f
|
|
@ -603,7 +603,6 @@ func (h *Handlers) UploadObject(w http.ResponseWriter, r *http.Request) {
|
|||
JSONResponse(w, http.StatusCreated, map[string]string{
|
||||
"object_id": objectID,
|
||||
"filename": header.Filename,
|
||||
"size": json.Number(strings.TrimRight(strings.TrimRight(json.Number("0").String(), "0"), ".")).String(),
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,427 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/mish/dealspace/lib"
|
||||
)
|
||||
|
||||
func TestFullFlow(t *testing.T) {
|
||||
// Setup test database
|
||||
tmpFile, err := os.CreateTemp("", "dealspace-integration-test-*.db")
|
||||
if err != nil {
|
||||
t.Fatalf("create temp file: %v", err)
|
||||
}
|
||||
tmpFile.Close()
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
db, err := lib.OpenDB(tmpFile.Name(), "../migrations/001_initial.sql")
|
||||
if err != nil {
|
||||
t.Fatalf("OpenDB: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
masterKey := make([]byte, 32)
|
||||
for i := range masterKey {
|
||||
masterKey[i] = byte(i)
|
||||
}
|
||||
jwtSecret := []byte("test-jwt-secret-32-bytes-long!!")
|
||||
|
||||
cfg := &lib.Config{
|
||||
MasterKey: masterKey,
|
||||
JWTSecret: jwtSecret,
|
||||
}
|
||||
|
||||
// Create test store
|
||||
tmpDir, err := os.MkdirTemp("", "dealspace-store-test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
store, _ := lib.NewLocalStore(tmpDir)
|
||||
|
||||
// Create router
|
||||
router := NewRouter(db, cfg, store, nil, nil)
|
||||
server := httptest.NewServer(router)
|
||||
defer server.Close()
|
||||
|
||||
client := &http.Client{}
|
||||
|
||||
// Step 1: POST /api/setup → create admin
|
||||
t.Log("Step 1: Setup admin user")
|
||||
setupBody := map[string]string{
|
||||
"email": "admin@test.com",
|
||||
"name": "Admin User",
|
||||
"password": "SecurePassword123!",
|
||||
}
|
||||
setupJSON, _ := json.Marshal(setupBody)
|
||||
resp, err := client.Post(server.URL+"/api/setup", "application/json", bytes.NewReader(setupJSON))
|
||||
if err != nil {
|
||||
t.Fatalf("setup request failed: %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusCreated {
|
||||
var errResp map[string]string
|
||||
json.NewDecoder(resp.Body).Decode(&errResp)
|
||||
t.Fatalf("setup expected 201, got %d: %v", resp.StatusCode, errResp)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
// Verify setup cannot be called again
|
||||
resp, _ = client.Post(server.URL+"/api/setup", "application/json", bytes.NewReader(setupJSON))
|
||||
if resp.StatusCode != http.StatusForbidden {
|
||||
t.Errorf("second setup should return 403 Forbidden, got %d", resp.StatusCode)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
// Step 2: POST /api/auth/login → get token
|
||||
t.Log("Step 2: Login")
|
||||
loginBody := map[string]string{
|
||||
"email": "admin@test.com",
|
||||
"password": "SecurePassword123!",
|
||||
}
|
||||
loginJSON, _ := json.Marshal(loginBody)
|
||||
resp, err = client.Post(server.URL+"/api/auth/login", "application/json", bytes.NewReader(loginJSON))
|
||||
if err != nil {
|
||||
t.Fatalf("login request failed: %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var errResp map[string]string
|
||||
json.NewDecoder(resp.Body).Decode(&errResp)
|
||||
t.Fatalf("login expected 200, got %d: %v", resp.StatusCode, errResp)
|
||||
}
|
||||
var loginResp map[string]interface{}
|
||||
json.NewDecoder(resp.Body).Decode(&loginResp)
|
||||
resp.Body.Close()
|
||||
|
||||
token, ok := loginResp["token"].(string)
|
||||
if !ok || token == "" {
|
||||
t.Fatal("login response should contain token")
|
||||
}
|
||||
t.Logf("Got token: %s...", token[:20])
|
||||
|
||||
// Wrong password should fail
|
||||
wrongLogin := map[string]string{
|
||||
"email": "admin@test.com",
|
||||
"password": "WrongPassword",
|
||||
}
|
||||
wrongJSON, _ := json.Marshal(wrongLogin)
|
||||
resp, _ = client.Post(server.URL+"/api/auth/login", "application/json", bytes.NewReader(wrongJSON))
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("wrong password should return 401, got %d", resp.StatusCode)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
// Step 3: GET /api/auth/me → verify user returned
|
||||
t.Log("Step 3: Get current user")
|
||||
req, _ := http.NewRequest("GET", server.URL+"/api/auth/me", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
resp, err = client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("me request failed: %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("me expected 200, got %d", resp.StatusCode)
|
||||
}
|
||||
var meResp map[string]string
|
||||
json.NewDecoder(resp.Body).Decode(&meResp)
|
||||
resp.Body.Close()
|
||||
|
||||
if meResp["email"] != "admin@test.com" {
|
||||
t.Errorf("me response email mismatch: got %s", meResp["email"])
|
||||
}
|
||||
t.Logf("Current user: %s (%s)", meResp["name"], meResp["email"])
|
||||
|
||||
// Step 4: POST /api/projects → create project
|
||||
t.Log("Step 4: Create project")
|
||||
projectBody := map[string]string{
|
||||
"name": "Test Deal Project",
|
||||
"deal_type": "M&A",
|
||||
}
|
||||
projectJSON, _ := json.Marshal(projectBody)
|
||||
req, _ = http.NewRequest("POST", server.URL+"/api/projects", bytes.NewReader(projectJSON))
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, err = client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("create project request failed: %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusCreated {
|
||||
var errResp map[string]string
|
||||
json.NewDecoder(resp.Body).Decode(&errResp)
|
||||
t.Fatalf("create project expected 201, got %d: %v", resp.StatusCode, errResp)
|
||||
}
|
||||
var projectResp map[string]interface{}
|
||||
json.NewDecoder(resp.Body).Decode(&projectResp)
|
||||
resp.Body.Close()
|
||||
|
||||
projectID := projectResp["project_id"].(string)
|
||||
if projectID == "" {
|
||||
t.Fatal("project response should contain project_id")
|
||||
}
|
||||
t.Logf("Created project: %s", projectID)
|
||||
|
||||
// Step 5: GET /api/projects → verify project listed
|
||||
t.Log("Step 5: List projects")
|
||||
req, _ = http.NewRequest("GET", server.URL+"/api/projects", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
resp, err = client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("list projects request failed: %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("list projects expected 200, got %d", resp.StatusCode)
|
||||
}
|
||||
var listResp []lib.Entry
|
||||
json.NewDecoder(resp.Body).Decode(&listResp)
|
||||
resp.Body.Close()
|
||||
|
||||
if len(listResp) < 1 {
|
||||
t.Errorf("expected at least 1 project, got %d", len(listResp))
|
||||
}
|
||||
t.Logf("Found %d projects", len(listResp))
|
||||
|
||||
// Step 6: POST /api/auth/logout → token invalidated
|
||||
t.Log("Step 6: Logout")
|
||||
req, _ = http.NewRequest("POST", server.URL+"/api/auth/logout", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
resp, err = client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("logout request failed: %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("logout expected 200, got %d", resp.StatusCode)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
// Step 7: GET /api/auth/me with old token → 401
|
||||
t.Log("Step 7: Verify token invalidated")
|
||||
req, _ = http.NewRequest("GET", server.URL+"/api/auth/me", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
resp, err = client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("me after logout request failed: %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Fatalf("me after logout expected 401, got %d", resp.StatusCode)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
t.Log("Full flow test passed!")
|
||||
}
|
||||
|
||||
func TestHealthEndpoint(t *testing.T) {
|
||||
tmpFile, _ := os.CreateTemp("", "dealspace-health-test-*.db")
|
||||
tmpFile.Close()
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
db, _ := lib.OpenDB(tmpFile.Name(), "../migrations/001_initial.sql")
|
||||
defer db.Close()
|
||||
|
||||
cfg := &lib.Config{
|
||||
MasterKey: make([]byte, 32),
|
||||
JWTSecret: []byte("test-secret"),
|
||||
}
|
||||
|
||||
router := NewRouter(db, cfg, nil, nil, nil)
|
||||
server := httptest.NewServer(router)
|
||||
defer server.Close()
|
||||
|
||||
resp, err := http.Get(server.URL + "/health")
|
||||
if err != nil {
|
||||
t.Fatalf("health request failed: %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("health expected 200, got %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var healthResp map[string]string
|
||||
json.NewDecoder(resp.Body).Decode(&healthResp)
|
||||
resp.Body.Close()
|
||||
|
||||
if healthResp["status"] != "ok" {
|
||||
t.Errorf("health status should be 'ok', got %s", healthResp["status"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnauthenticatedAccess(t *testing.T) {
|
||||
tmpFile, _ := os.CreateTemp("", "dealspace-unauth-test-*.db")
|
||||
tmpFile.Close()
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
db, _ := lib.OpenDB(tmpFile.Name(), "../migrations/001_initial.sql")
|
||||
defer db.Close()
|
||||
|
||||
cfg := &lib.Config{
|
||||
MasterKey: make([]byte, 32),
|
||||
JWTSecret: []byte("test-secret"),
|
||||
}
|
||||
|
||||
router := NewRouter(db, cfg, nil, nil, nil)
|
||||
server := httptest.NewServer(router)
|
||||
defer server.Close()
|
||||
|
||||
// These endpoints require auth
|
||||
endpoints := []struct {
|
||||
method string
|
||||
path string
|
||||
}{
|
||||
{"GET", "/api/auth/me"},
|
||||
{"POST", "/api/auth/logout"},
|
||||
{"GET", "/api/projects"},
|
||||
{"POST", "/api/projects"},
|
||||
{"GET", "/api/projects/test/entries"},
|
||||
}
|
||||
|
||||
for _, ep := range endpoints {
|
||||
req, _ := http.NewRequest(ep.method, server.URL+ep.path, nil)
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Errorf("%s %s: request failed: %v", ep.method, ep.path, err)
|
||||
continue
|
||||
}
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("%s %s: expected 401, got %d", ep.method, ep.path, resp.StatusCode)
|
||||
}
|
||||
resp.Body.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func TestEntryOperations(t *testing.T) {
|
||||
tmpFile, _ := os.CreateTemp("", "dealspace-entry-test-*.db")
|
||||
tmpFile.Close()
|
||||
defer os.Remove(tmpFile.Name())
|
||||
|
||||
db, _ := lib.OpenDB(tmpFile.Name(), "../migrations/001_initial.sql")
|
||||
defer db.Close()
|
||||
|
||||
masterKey := make([]byte, 32)
|
||||
jwtSecret := []byte("test-secret-32-bytes!!")
|
||||
|
||||
cfg := &lib.Config{
|
||||
MasterKey: masterKey,
|
||||
JWTSecret: jwtSecret,
|
||||
}
|
||||
|
||||
tmpDir, _ := os.MkdirTemp("", "dealspace-store-entry-test")
|
||||
defer os.RemoveAll(tmpDir)
|
||||
store, _ := lib.NewLocalStore(tmpDir)
|
||||
|
||||
router := NewRouter(db, cfg, store, nil, nil)
|
||||
server := httptest.NewServer(router)
|
||||
defer server.Close()
|
||||
|
||||
client := &http.Client{}
|
||||
|
||||
// Setup and login
|
||||
setupBody, _ := json.Marshal(map[string]string{
|
||||
"email": "entry@test.com", "name": "Entry Test", "password": "pass12345678",
|
||||
})
|
||||
client.Post(server.URL+"/api/setup", "application/json", bytes.NewReader(setupBody))
|
||||
|
||||
loginBody, _ := json.Marshal(map[string]string{
|
||||
"email": "entry@test.com", "password": "pass12345678",
|
||||
})
|
||||
resp, _ := client.Post(server.URL+"/api/auth/login", "application/json", bytes.NewReader(loginBody))
|
||||
var loginResp map[string]interface{}
|
||||
json.NewDecoder(resp.Body).Decode(&loginResp)
|
||||
resp.Body.Close()
|
||||
token := loginResp["token"].(string)
|
||||
|
||||
// Create project
|
||||
projectBody, _ := json.Marshal(map[string]string{"name": "Entry Test Project"})
|
||||
req, _ := http.NewRequest("POST", server.URL+"/api/projects", bytes.NewReader(projectBody))
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ = client.Do(req)
|
||||
var projectResp map[string]interface{}
|
||||
json.NewDecoder(resp.Body).Decode(&projectResp)
|
||||
resp.Body.Close()
|
||||
projectID := projectResp["project_id"].(string)
|
||||
|
||||
// Create entry
|
||||
entryBody, _ := json.Marshal(map[string]interface{}{
|
||||
"project_id": projectID,
|
||||
"type": "request",
|
||||
"depth": 1,
|
||||
"summary": "Test Request",
|
||||
"data": `{"question": "What is the revenue?"}`,
|
||||
"stage": "pre_dataroom",
|
||||
})
|
||||
req, _ = http.NewRequest("POST", server.URL+"/api/projects/"+projectID+"/entries", bytes.NewReader(entryBody))
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ = client.Do(req)
|
||||
if resp.StatusCode != http.StatusCreated {
|
||||
var errResp map[string]string
|
||||
json.NewDecoder(resp.Body).Decode(&errResp)
|
||||
t.Fatalf("create entry expected 201, got %d: %v", resp.StatusCode, errResp)
|
||||
}
|
||||
var entryResp lib.Entry
|
||||
json.NewDecoder(resp.Body).Decode(&entryResp)
|
||||
resp.Body.Close()
|
||||
entryID := entryResp.EntryID
|
||||
|
||||
if entryID == "" {
|
||||
t.Fatal("entry should have ID")
|
||||
}
|
||||
|
||||
// List entries
|
||||
req, _ = http.NewRequest("GET", server.URL+"/api/projects/"+projectID+"/entries?type=request", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
resp, _ = client.Do(req)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("list entries expected 200, got %d", resp.StatusCode)
|
||||
}
|
||||
var entries []lib.Entry
|
||||
json.NewDecoder(resp.Body).Decode(&entries)
|
||||
resp.Body.Close()
|
||||
|
||||
if len(entries) != 1 {
|
||||
t.Errorf("expected 1 entry, got %d", len(entries))
|
||||
}
|
||||
|
||||
// Update entry
|
||||
updateBody, _ := json.Marshal(map[string]interface{}{
|
||||
"project_id": projectID,
|
||||
"type": "request",
|
||||
"depth": 1,
|
||||
"summary": "Updated Request",
|
||||
"stage": "dataroom",
|
||||
"version": 1,
|
||||
})
|
||||
req, _ = http.NewRequest("PUT", server.URL+"/api/projects/"+projectID+"/entries/"+entryID, bytes.NewReader(updateBody))
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp, _ = client.Do(req)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var errResp map[string]string
|
||||
json.NewDecoder(resp.Body).Decode(&errResp)
|
||||
t.Fatalf("update entry expected 200, got %d: %v", resp.StatusCode, errResp)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
// Delete entry
|
||||
req, _ = http.NewRequest("DELETE", server.URL+"/api/projects/"+projectID+"/entries/"+entryID, nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
resp, _ = client.Do(req)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("delete entry expected 200, got %d", resp.StatusCode)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
// Verify deleted (should not appear in list)
|
||||
req, _ = http.NewRequest("GET", server.URL+"/api/projects/"+projectID+"/entries?type=request", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
resp, _ = client.Do(req)
|
||||
json.NewDecoder(resp.Body).Decode(&entries)
|
||||
resp.Body.Close()
|
||||
|
||||
if len(entries) != 0 {
|
||||
t.Errorf("expected 0 entries after delete, got %d", len(entries))
|
||||
}
|
||||
}
|
||||
|
|
@ -74,7 +74,8 @@ func createTestUserAndSession(t *testing.T, db *lib.DB, cfg *lib.Config) (*lib.U
|
|||
return user, session
|
||||
}
|
||||
|
||||
func createJWT(userID, sessionID string, expiresAt int64, secret []byte) string {
|
||||
// testCreateJWT creates a JWT for testing (different signature from package createJWT)
|
||||
func testCreateJWT(userID, sessionID string, expiresAt int64, secret []byte) string {
|
||||
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"HS256","typ":"JWT"}`))
|
||||
|
||||
claims := map[string]interface{}{
|
||||
|
|
@ -99,7 +100,7 @@ func TestAuthMiddleware_ValidToken(t *testing.T) {
|
|||
user, session := createTestUserAndSession(t, db, cfg)
|
||||
|
||||
// Create valid JWT
|
||||
token := createJWT(user.UserID, session.ID, time.Now().Unix()+3600, cfg.JWTSecret)
|
||||
token := testCreateJWT(user.UserID, session.ID, time.Now().Unix()+3600, cfg.JWTSecret)
|
||||
|
||||
// Create test handler that checks user ID
|
||||
var capturedUserID string
|
||||
|
|
@ -150,7 +151,7 @@ func TestAuthMiddleware_ExpiredToken(t *testing.T) {
|
|||
user, session := createTestUserAndSession(t, db, cfg)
|
||||
|
||||
// Create expired JWT (expired 1 hour ago)
|
||||
token := createJWT(user.UserID, session.ID, time.Now().Unix()-3600, cfg.JWTSecret)
|
||||
token := testCreateJWT(user.UserID, session.ID, time.Now().Unix()-3600, cfg.JWTSecret)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
|
@ -184,7 +185,7 @@ func TestAuthMiddleware_InvalidToken(t *testing.T) {
|
|||
}{
|
||||
{"garbage", "not-a-jwt"},
|
||||
{"malformed", "a.b.c.d.e"},
|
||||
{"wrong signature", createJWT("user", "session", time.Now().Unix()+3600, []byte("wrong-secret"))},
|
||||
{"wrong signature", testCreateJWT("user", "session", time.Now().Unix()+3600, []byte("wrong-secret"))},
|
||||
{"empty bearer", ""},
|
||||
}
|
||||
|
||||
|
|
@ -212,7 +213,7 @@ func TestAuthMiddleware_RevokedSession(t *testing.T) {
|
|||
user, session := createTestUserAndSession(t, db, cfg)
|
||||
|
||||
// Create valid JWT
|
||||
token := createJWT(user.UserID, session.ID, time.Now().Unix()+3600, cfg.JWTSecret)
|
||||
token := testCreateJWT(user.UserID, session.ID, time.Now().Unix()+3600, cfg.JWTSecret)
|
||||
|
||||
// Revoke the session
|
||||
if err := lib.SessionRevoke(db, session.ID); err != nil {
|
||||
|
|
@ -263,7 +264,7 @@ func TestAuthMiddleware_ExpiredSession(t *testing.T) {
|
|||
lib.SessionCreate(db, session)
|
||||
|
||||
// Create JWT that hasn't expired (but session has)
|
||||
token := createJWT(user.UserID, session.ID, time.Now().Unix()+3600, cfg.JWTSecret)
|
||||
token := testCreateJWT(user.UserID, session.ID, time.Now().Unix()+3600, cfg.JWTSecret)
|
||||
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
|
|
|
|||
|
|
@ -25,25 +25,59 @@ func NewRouter(db *lib.DB, cfg *lib.Config, store lib.ObjectStore, websiteFS fs.
|
|||
r.Post("/api/chat", h.ChatHandler)
|
||||
r.Options("/api/chat", h.ChatHandler)
|
||||
|
||||
// Auth endpoints (unauthenticated)
|
||||
r.Post("/api/auth/login", h.Login)
|
||||
r.Post("/api/setup", h.Setup)
|
||||
|
||||
// Auth endpoints (need token for logout/me)
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(AuthMiddleware(db, cfg.JWTSecret))
|
||||
r.Post("/api/auth/logout", h.Logout)
|
||||
r.Get("/api/auth/me", h.Me)
|
||||
})
|
||||
|
||||
// API routes (authenticated)
|
||||
r.Route("/api", func(r chi.Router) {
|
||||
r.Use(AuthMiddleware(db, cfg.JWTSecret))
|
||||
|
||||
// Tasks (cross-project)
|
||||
r.Get("/tasks", h.GetAllTasks)
|
||||
|
||||
// Projects
|
||||
r.Get("/projects", h.GetAllProjects)
|
||||
r.Post("/projects", h.CreateProject)
|
||||
r.Get("/projects/{projectID}/detail", h.GetProjectDetail)
|
||||
|
||||
// Workstreams
|
||||
r.Post("/projects/{projectID}/workstreams", h.CreateWorkstream)
|
||||
|
||||
// Entries
|
||||
r.Get("/projects/{projectID}/entries", h.ListEntries)
|
||||
r.Post("/projects/{projectID}/entries", h.CreateEntry)
|
||||
r.Put("/projects/{projectID}/entries/{entryID}", h.UpdateEntry)
|
||||
r.Delete("/projects/{projectID}/entries/{entryID}", h.DeleteEntry)
|
||||
|
||||
// Task inbox
|
||||
// Task inbox (per-project)
|
||||
r.Get("/projects/{projectID}/tasks", h.GetMyTasks)
|
||||
|
||||
// Requests
|
||||
r.Get("/requests/{requestID}", h.GetRequestDetail)
|
||||
|
||||
// File upload/download
|
||||
r.Post("/projects/{projectID}/objects", h.UploadObject)
|
||||
r.Get("/projects/{projectID}/objects/{objectID}", h.DownloadObject)
|
||||
})
|
||||
|
||||
// Deal room UI (portal)
|
||||
if portalFS != nil {
|
||||
portalHandler := http.FileServerFS(portalFS)
|
||||
r.Handle("/app/*", http.StripPrefix("/app", portalHandler))
|
||||
}
|
||||
// Portal app routes (serve templates, auth checked client-side via JS)
|
||||
r.Get("/app", func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "/app/tasks", http.StatusFound)
|
||||
})
|
||||
r.Get("/app/login", h.ServeLogin)
|
||||
r.Get("/app/setup", h.ServeSetup)
|
||||
r.Get("/app/tasks", h.ServeAppTasks)
|
||||
r.Get("/app/projects", h.ServeAppProjects)
|
||||
r.Get("/app/projects/{id}", h.ServeAppProject)
|
||||
r.Get("/app/requests/{id}", h.ServeAppRequest)
|
||||
|
||||
// Marketing website (embedded static files) — serves at root, must be last
|
||||
if websiteFS != nil {
|
||||
|
|
|
|||
|
|
@ -592,10 +592,14 @@ func TestLocalStore(t *testing.T) {
|
|||
|
||||
// Write
|
||||
data := []byte("test object data")
|
||||
id := "abcdef1234567890"
|
||||
if err := store.Write(id, data); err != nil {
|
||||
projectID := "test-project"
|
||||
id, err := store.Write(projectID, data)
|
||||
if err != nil {
|
||||
t.Fatalf("Write: %v", err)
|
||||
}
|
||||
if id == "" {
|
||||
t.Error("Write should return an ID")
|
||||
}
|
||||
|
||||
// Exists
|
||||
if !store.Exists(id) {
|
||||
|
|
@ -603,7 +607,7 @@ func TestLocalStore(t *testing.T) {
|
|||
}
|
||||
|
||||
// Read
|
||||
read, err := store.Read(id)
|
||||
read, err := store.Read(projectID, id)
|
||||
if err != nil {
|
||||
t.Fatalf("Read: %v", err)
|
||||
}
|
||||
|
|
@ -620,7 +624,7 @@ func TestLocalStore(t *testing.T) {
|
|||
}
|
||||
|
||||
// Read nonexistent
|
||||
_, err = store.Read("nonexistent")
|
||||
_, err = store.Read(projectID, "nonexistent")
|
||||
if err != ErrObjectNotFound {
|
||||
t.Errorf("expected ErrObjectNotFound, got %v", err)
|
||||
}
|
||||
|
|
|
|||
97
lib/store.go
97
lib/store.go
|
|
@ -14,16 +14,17 @@ var (
|
|||
|
||||
// ObjectStore is the interface for encrypted file storage.
|
||||
type ObjectStore interface {
|
||||
Write(id string, data []byte) error
|
||||
Read(id string) ([]byte, error)
|
||||
Delete(id string) error
|
||||
Exists(id string) bool
|
||||
Write(projectID string, data []byte) (string, error)
|
||||
Read(projectID, objectID string) ([]byte, error)
|
||||
Delete(objectID string) error
|
||||
Exists(objectID string) bool
|
||||
}
|
||||
|
||||
// LocalStore implements ObjectStore using the local filesystem.
|
||||
// Files are stored in a two-level directory structure based on the first 4 hex chars of the ID.
|
||||
type LocalStore struct {
|
||||
BasePath string
|
||||
MasterKey []byte
|
||||
}
|
||||
|
||||
// NewLocalStore creates a new local filesystem object store.
|
||||
|
|
@ -42,32 +43,68 @@ func (s *LocalStore) objectPath(id string) string {
|
|||
return filepath.Join(s.BasePath, id[:2], id[2:4], id)
|
||||
}
|
||||
|
||||
func (s *LocalStore) Write(id string, data []byte) error {
|
||||
path := s.objectPath(id)
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil {
|
||||
return err
|
||||
// Write encrypts data and writes to store. Returns the object ID.
|
||||
func (s *LocalStore) Write(projectID string, data []byte) (string, error) {
|
||||
// Derive project-specific key if master key is set
|
||||
if len(s.MasterKey) > 0 {
|
||||
key, err := DeriveProjectKey(s.MasterKey, projectID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return os.WriteFile(path, data, 0600)
|
||||
encrypted, err := ObjectEncrypt(key, data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
data = encrypted
|
||||
}
|
||||
|
||||
func (s *LocalStore) Read(id string) ([]byte, error) {
|
||||
data, err := os.ReadFile(s.objectPath(id))
|
||||
// Compute content-addressable ID
|
||||
id := ObjectID(data)
|
||||
path := s.objectPath(id)
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0700); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := os.WriteFile(path, data, 0600); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// Read reads and decrypts data from store.
|
||||
func (s *LocalStore) Read(projectID, objectID string) ([]byte, error) {
|
||||
data, err := os.ReadFile(s.objectPath(objectID))
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return nil, ErrObjectNotFound
|
||||
}
|
||||
return data, err
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (s *LocalStore) Delete(id string) error {
|
||||
err := os.Remove(s.objectPath(id))
|
||||
// Decrypt if master key is set
|
||||
if len(s.MasterKey) > 0 {
|
||||
key, err := DeriveProjectKey(s.MasterKey, projectID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data, err = ObjectDecrypt(key, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (s *LocalStore) Delete(objectID string) error {
|
||||
err := os.Remove(s.objectPath(objectID))
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *LocalStore) Exists(id string) bool {
|
||||
_, err := os.Stat(s.objectPath(id))
|
||||
func (s *LocalStore) Exists(objectID string) bool {
|
||||
_, err := os.Stat(s.objectPath(objectID))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
|
|
@ -79,36 +116,12 @@ func ObjectID(encryptedData []byte) string {
|
|||
|
||||
// ObjectWrite encrypts data and writes to store. Returns the object ID.
|
||||
func ObjectWrite(db *DB, store ObjectStore, cfg *Config, projectID string, data []byte) (string, error) {
|
||||
key, err := DeriveProjectKey(cfg.MasterKey, projectID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
encrypted, err := ObjectEncrypt(key, data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
id := ObjectID(encrypted)
|
||||
if err := store.Write(id, encrypted); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return id, nil
|
||||
return store.Write(projectID, data)
|
||||
}
|
||||
|
||||
// ObjectRead reads and decrypts data from store.
|
||||
func ObjectRead(db *DB, store ObjectStore, cfg *Config, projectID, objectID string) ([]byte, error) {
|
||||
encrypted, err := store.Read(objectID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key, err := DeriveProjectKey(cfg.MasterKey, projectID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ObjectDecrypt(key, encrypted)
|
||||
return store.Read(projectID, objectID)
|
||||
}
|
||||
|
||||
// ObjectDelete removes an object from store.
|
||||
|
|
|
|||
Loading…
Reference in New Issue