vault1984/commercial/account/db.go

193 lines
4.5 KiB
Go

package main
import (
"database/sql"
"log"
"time"
_ "github.com/mattn/go-sqlite3"
)
// db is the package-wide database handle. initDB assigns it, and every query
// helper in this file reads it.
var db *sql.DB

// initDB opens the SQLite database at path and creates the schema.
//
// The DSN asks the go-sqlite3 driver for WAL journaling and a busy timeout of
// 5000 (milliseconds, per the driver's DSN conventions). Every statement below
// uses CREATE TABLE IF NOT EXISTS, so running initDB against an existing
// database is harmless. Any failure is fatal: log.Fatalf exits the process.
//
// NOTE(review): SQLite does not enforce the REFERENCES clauses below unless
// foreign keys are enabled on the connection. This DSN does not enable them;
// confirm whether that is intended.
func initDB(path string) {
	var err error
	if db, err = sql.Open("sqlite3", path+"?_journal=WAL&_busy_timeout=5000"); err != nil {
		log.Fatalf("db open: %v", err)
	}

	schema := [...]string{
		`CREATE TABLE IF NOT EXISTS accounts (
email TEXT PRIMARY KEY,
stripe_id TEXT DEFAULT '',
created_at TEXT NOT NULL DEFAULT (datetime('now'))
)`,
		`CREATE TABLE IF NOT EXISTS vaults (
vault_id TEXT PRIMARY KEY,
account_email TEXT NOT NULL REFERENCES accounts(email),
region TEXT NOT NULL,
expires_at TEXT NOT NULL,
created_at TEXT NOT NULL DEFAULT (datetime('now'))
)`,
		`CREATE TABLE IF NOT EXISTS sessions (
token TEXT PRIMARY KEY,
email TEXT NOT NULL REFERENCES accounts(email),
expires_at TEXT NOT NULL,
created_at TEXT NOT NULL DEFAULT (datetime('now'))
)`,
		`CREATE TABLE IF NOT EXISTS login_codes (
email TEXT PRIMARY KEY,
code TEXT NOT NULL,
expires_at TEXT NOT NULL
)`,
	}

	// Apply the statements in order; vaults and sessions refer to accounts.
	for i := range schema {
		if _, err = db.Exec(schema[i]); err != nil {
			log.Fatalf("migration: %v", err)
		}
	}

	log.Println(" database ready")
}
// Accounts
// accountGet looks up the account stored under email.
//
// It returns whether the account exists and, if it does, its stripe_id column
// (the schema defaults this to ''). A missing account is not an error: it
// yields (false, "", nil). Any other database error is returned with
// (false, "").
func accountGet(email string) (bool, string, error) {
	var stripeID string
	row := db.QueryRow("SELECT stripe_id FROM accounts WHERE email = ?", email)
	switch err := row.Scan(&stripeID); {
	case err == sql.ErrNoRows:
		return false, "", nil
	case err != nil:
		return false, "", err
	default:
		return true, stripeID, nil
	}
}
// accountCreate makes sure an accounts row exists for email. INSERT OR IGNORE
// makes this a no-op, not an error, when the account already exists. Only
// email is supplied; the other columns take their schema defaults.
func accountCreate(email string) error {
	const stmt = "INSERT OR IGNORE INTO accounts (email) VALUES (?)"
	if _, err := db.Exec(stmt, email); err != nil {
		return err
	}
	return nil
}
// Vaults
// vaultList returns the vaults owned by email, ordered by created_at
// ascending. When the account has no vaults it returns a nil slice and a nil
// error.
//
// Any error from the query, from scanning a row, or from iterating the result
// set is returned with a nil slice. Callers never receive a partial listing
// that looks successful.
func vaultList(email string) ([]Vault, error) {
	rows, err := db.Query(
		"SELECT vault_id, region, expires_at, created_at FROM vaults WHERE account_email = ? ORDER BY created_at",
		email,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var vaults []Vault
	for rows.Next() {
		var v Vault
		if err := rows.Scan(&v.ID, &v.Region, &v.ExpiresAt, &v.CreatedAt); err != nil {
			return nil, err
		}
		vaults = append(vaults, v)
	}
	// rows.Next returns false both when the rows are exhausted and when
	// iteration fails; only rows.Err tells the two apart. Without this check,
	// a mid-iteration error would come back as a truncated, "successful" list.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return vaults, nil
}
// vaultCount returns the number of vaults whose account_email is email,
// together with any error from the query.
func vaultCount(email string) (int, error) {
	var total int
	if err := db.QueryRow("SELECT COUNT(*) FROM vaults WHERE account_email = ?", email).Scan(&total); err != nil {
		return total, err
	}
	return total, nil
}
// vaultCreate inserts a vault row with the given id, owner email and region.
// The vault expires one calendar year from now; the expiry is stored as an
// RFC 3339 timestamp in UTC. Insert errors, such as a duplicate vault_id,
// are returned unchanged.
func vaultCreate(id, email, region string) error {
	expiresAt := time.Now().AddDate(1, 0, 0).UTC()
	const insert = "INSERT INTO vaults (vault_id, account_email, region, expires_at) VALUES (?, ?, ?, ?)"
	_, err := db.Exec(insert, id, email, region, expiresAt.Format(time.RFC3339))
	return err
}
// vaultDelete removes vault id, but only if it belongs to email.
//
// It returns sql.ErrNoRows when no row matched: the vault does not exist, or
// it is owned by a different account. Database failures, including a failure
// to read the affected-row count, are returned as-is. They are never
// misreported as sql.ErrNoRows.
func vaultDelete(id, email string) error {
	res, err := db.Exec("DELETE FROM vaults WHERE vault_id = ? AND account_email = ?", id, email)
	if err != nil {
		return err
	}
	n, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if n == 0 {
		return sql.ErrNoRows
	}
	return nil
}
// Sessions
// sessionCreate issues a new session for email that expires 24 hours from
// now. The expiry is stored as an RFC 3339 timestamp in UTC. It returns the
// session token on success.
//
// The token comes from randomToken(32), which is defined elsewhere. If the
// insert fails, it returns "" together with the error, so a token that was
// never stored cannot leak to the caller.
func sessionCreate(email string) (string, error) {
	token := randomToken(32)
	expires := time.Now().Add(24 * time.Hour).UTC().Format(time.RFC3339)
	if _, err := db.Exec(
		"INSERT INTO sessions (token, email, expires_at) VALUES (?, ?, ?)",
		token, email, expires,
	); err != nil {
		return "", err
	}
	return token, nil
}
// sessionGet resolves a session token to the email it was issued for.
//
// An unknown token yields the Scan error (sql.ErrNoRows when no row
// matches). An expired session is deleted on a best-effort basis and also
// reported as sql.ErrNoRows.
func sessionGet(token string) (string, error) {
	var (
		owner     string
		expiresAt string
	)
	row := db.QueryRow("SELECT email, expires_at FROM sessions WHERE token = ?", token)
	if err := row.Scan(&owner, &expiresAt); err != nil {
		return "", err
	}

	// A malformed timestamp parses to the zero time. That is always in the
	// past, so a corrupt row is treated as expired; the parse error is
	// deliberately ignored.
	deadline, _ := time.Parse(time.RFC3339, expiresAt)
	if !time.Now().After(deadline) {
		return owner, nil
	}

	db.Exec("DELETE FROM sessions WHERE token = ?", token)
	return "", sql.ErrNoRows
}
// sessionDelete removes the session identified by token. It is best-effort:
// any error from the database is discarded.
func sessionDelete(token string) {
	const stmt = "DELETE FROM sessions WHERE token = ?"
	_, _ = db.Exec(stmt, token)
}
// Login codes
// loginCodeSet stores code as the pending login code for email, valid for
// 10 minutes. The expiry is an RFC 3339 timestamp in UTC. email is the
// table's primary key, so INSERT OR REPLACE overwrites any code issued
// earlier for the same address.
func loginCodeSet(email, code string) error {
	const validFor = 10 * time.Minute
	deadline := time.Now().Add(validFor).UTC()
	_, err := db.Exec(
		"INSERT OR REPLACE INTO login_codes (email, code, expires_at) VALUES (?, ?, ?)",
		email, code, deadline.Format(time.RFC3339),
	)
	return err
}
// loginCodeVerify reports whether code matches the unexpired login code
// stored for email.
//
// A matching code is consumed (deleted), so it works only once. An expired
// code is deleted and rejected. The result is (false, nil) when no code is
// stored, the code has expired, or it does not match. A non-nil error is
// returned only when the database itself fails.
//
// NOTE(review): nothing here limits the number of guesses per email, and the
// schema has no attempts counter. Rate limiting presumably happens in the
// caller; confirm. The comparison is also not constant-time
// (crypto/subtle.ConstantTimeCompare would be).
func loginCodeVerify(email, code string) (bool, error) {
	var stored, expires string
	err := db.QueryRow("SELECT code, expires_at FROM login_codes WHERE email = ?", email).Scan(&stored, &expires)
	if err == sql.ErrNoRows {
		return false, nil
	}
	if err != nil {
		// Surface real database failures instead of disguising them as a
		// wrong code.
		return false, err
	}

	// A malformed timestamp parses to the zero time, which is always in the
	// past, so a corrupt row is treated as expired.
	t, _ := time.Parse(time.RFC3339, expires)
	if time.Now().After(t) {
		// Best-effort cleanup; the code is rejected either way.
		db.Exec("DELETE FROM login_codes WHERE email = ?", email)
		return false, nil
	}
	if stored != code {
		return false, nil
	}

	// Consume the code with a conditional delete. If several requests verify
	// the same code concurrently, only the one that actually removes the row
	// succeeds. Requiring the code to match here also rejects the attempt if
	// a different code was issued for email after the SELECT above.
	res, err := db.Exec("DELETE FROM login_codes WHERE email = ? AND code = ?", email, code)
	if err != nil {
		return false, err
	}
	n, err := res.RowsAffected()
	if err != nil {
		return false, err
	}
	return n > 0, nil
}
// Types
// Vault is one row of the vaults table, as scanned by vaultList, with JSON
// tags for serialization. The timestamps are kept as the raw TEXT stored in
// SQLite rather than parsed into time.Time.
type Vault struct {
ID string `json:"vault_id"`
Region string `json:"region"`
// ExpiresAt is written by vaultCreate as an RFC 3339 UTC timestamp.
ExpiresAt string `json:"expires_at"`
// CreatedAt is filled by the column default datetime('now') when the row is inserted.
CreatedAt string `json:"created_at"`
}