328 lines
7.4 KiB
Go
328 lines
7.4 KiB
Go
package main
|
|
|
|
import (
|
|
"database/sql"
|
|
"encoding/csv"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"strings"
|
|
|
|
"inou/lib"
|
|
|
|
_ "github.com/mattn/go-sqlite3"
|
|
)
|
|
|
|
// dbPath is the SQLite database file opened (via the "sqlite3" driver) by both
// the one-shot CLI path in main and the HTTP mode in serveHTTP.
const dbPath = "/tank/inou/data/inou.db"
|
|
|
|
func main() {
|
|
if len(os.Args) < 2 {
|
|
fmt.Fprintln(os.Stderr, "Usage: dbquery [OPTIONS] <SQL>")
|
|
fmt.Fprintln(os.Stderr, " Runs SQL against inou.db, decrypts fields, outputs JSON (default).")
|
|
fmt.Fprintln(os.Stderr, "")
|
|
fmt.Fprintln(os.Stderr, "Options:")
|
|
fmt.Fprintln(os.Stderr, " -csv Output as CSV")
|
|
fmt.Fprintln(os.Stderr, " -table Output as formatted table")
|
|
fmt.Fprintln(os.Stderr, "")
|
|
fmt.Fprintln(os.Stderr, "Example: dbquery \"SELECT * FROM entries LIMIT 5\"")
|
|
os.Exit(1)
|
|
}
|
|
|
|
args := os.Args[1:]
|
|
|
|
if len(args) > 0 && (args[0] == "-serve" || args[0] == "--serve") {
|
|
serveHTTP()
|
|
return
|
|
}
|
|
|
|
format := "json"
|
|
if len(args) > 0 && (args[0] == "-csv" || args[0] == "--csv") {
|
|
format = "csv"
|
|
args = args[1:]
|
|
} else if len(args) > 0 && (args[0] == "-table" || args[0] == "--table") {
|
|
format = "table"
|
|
args = args[1:]
|
|
}
|
|
|
|
if len(args) == 0 {
|
|
fmt.Fprintln(os.Stderr, "Error: SQL query required")
|
|
os.Exit(1)
|
|
}
|
|
|
|
query := strings.Join(args, " ")
|
|
|
|
if err := lib.CryptoInit(lib.KeyPathDefault); err != nil {
|
|
fmt.Fprintf(os.Stderr, "crypto init: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
db, err := sql.Open("sqlite3", dbPath)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "db open: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
defer db.Close()
|
|
|
|
cols, results, err := queryDB(db, query)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "query: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
switch format {
|
|
case "csv":
|
|
outputCSV(cols, results)
|
|
case "table":
|
|
outputTable(cols, results)
|
|
default:
|
|
out, _ := json.MarshalIndent(results, "", " ")
|
|
fmt.Println(string(out))
|
|
}
|
|
}
|
|
|
|
func queryDB(db *sql.DB, query string) ([]string, []map[string]interface{}, error) {
|
|
rows, err := db.Query(query)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
cols, _ := rows.Columns()
|
|
var results []map[string]interface{}
|
|
|
|
for rows.Next() {
|
|
vals := make([]interface{}, len(cols))
|
|
ptrs := make([]interface{}, len(cols))
|
|
for i := range vals {
|
|
ptrs[i] = &vals[i]
|
|
}
|
|
if err := rows.Scan(ptrs...); err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
row := make(map[string]interface{})
|
|
for i, col := range cols {
|
|
v := vals[i]
|
|
switch val := v.(type) {
|
|
case []byte:
|
|
if unpacked := lib.Unpack(val); unpacked != nil {
|
|
s := string(unpacked)
|
|
if strings.HasPrefix(s, "{") || strings.HasPrefix(s, "[") {
|
|
var parsed interface{}
|
|
if json.Unmarshal(unpacked, &parsed) == nil {
|
|
row[col] = parsed
|
|
continue
|
|
}
|
|
}
|
|
row[col] = s
|
|
continue
|
|
}
|
|
s := string(val)
|
|
decrypted := s
|
|
for j := 0; j < 10; j++ {
|
|
next := lib.CryptoDecrypt(decrypted)
|
|
if next == "" || next == decrypted {
|
|
break
|
|
}
|
|
decrypted = next
|
|
}
|
|
if decrypted != s {
|
|
if strings.HasPrefix(decrypted, "{") || strings.HasPrefix(decrypted, "[") {
|
|
var parsed interface{}
|
|
if json.Unmarshal([]byte(decrypted), &parsed) == nil {
|
|
row[col] = parsed
|
|
continue
|
|
}
|
|
}
|
|
row[col] = decrypted
|
|
} else {
|
|
row[col] = s
|
|
}
|
|
case nil:
|
|
row[col] = nil
|
|
case string:
|
|
decrypted := val
|
|
for j := 0; j < 10; j++ {
|
|
next := lib.CryptoDecrypt(decrypted)
|
|
if next == "" || next == decrypted {
|
|
break
|
|
}
|
|
decrypted = next
|
|
}
|
|
if decrypted != val {
|
|
if strings.HasPrefix(decrypted, "{") || strings.HasPrefix(decrypted, "[") {
|
|
var parsed interface{}
|
|
if json.Unmarshal([]byte(decrypted), &parsed) == nil {
|
|
row[col] = parsed
|
|
continue
|
|
}
|
|
}
|
|
row[col] = decrypted
|
|
} else {
|
|
row[col] = val
|
|
}
|
|
default:
|
|
row[col] = v
|
|
}
|
|
}
|
|
results = append(results, row)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, nil, err
|
|
}
|
|
return cols, results, nil
|
|
}
|
|
|
|
// formatValue renders one queryDB cell as plain text for the CSV and table
// writers. A nil interface (SQL NULL) becomes "". Decoded JSON objects and
// arrays are re-encoded as compact JSON. Any other value, or a JSON value that
// fails to re-encode, is printed with the %v verb.
func formatValue(val interface{}) string {
	switch val.(type) {
	case nil:
		return ""
	case map[string]interface{}, []interface{}:
		if encoded, err := json.Marshal(val); err == nil {
			return string(encoded)
		}
	}
	return fmt.Sprintf("%v", val)
}
|
|
|
|
func outputCSV(cols []string, results []map[string]interface{}) {
|
|
w := csv.NewWriter(os.Stdout)
|
|
defer w.Flush()
|
|
w.Write(cols)
|
|
for _, row := range results {
|
|
record := make([]string, len(cols))
|
|
for i, col := range cols {
|
|
record[i] = formatValue(row[col])
|
|
}
|
|
w.Write(record)
|
|
}
|
|
}
|
|
|
|
func outputTable(cols []string, results []map[string]interface{}) {
|
|
if len(results) == 0 {
|
|
fmt.Println("(no rows)")
|
|
return
|
|
}
|
|
|
|
widths := make([]int, len(cols))
|
|
for i, col := range cols {
|
|
widths[i] = len(col)
|
|
}
|
|
for _, row := range results {
|
|
for i, col := range cols {
|
|
n := len(formatValue(row[col]))
|
|
if n > widths[i] {
|
|
widths[i] = n
|
|
}
|
|
}
|
|
}
|
|
for i := range widths {
|
|
if widths[i] > 50 {
|
|
widths[i] = 50
|
|
}
|
|
}
|
|
|
|
for i, col := range cols {
|
|
fmt.Printf("%-*s ", widths[i], col)
|
|
}
|
|
fmt.Println()
|
|
for i := range cols {
|
|
fmt.Printf("%s ", strings.Repeat("─", widths[i]))
|
|
}
|
|
fmt.Println()
|
|
|
|
for _, row := range results {
|
|
for i, col := range cols {
|
|
val := formatValue(row[col])
|
|
if len(val) > widths[i] {
|
|
val = val[:widths[i]-3] + "..."
|
|
}
|
|
fmt.Printf("%-*s ", widths[i], val)
|
|
}
|
|
fmt.Println()
|
|
}
|
|
}
|
|
|
|
// stagingIP is the local interface address that identifies the staging host;
// serveHTTP refuses to start unless isStaging finds it.
const stagingIP = "192.168.1.253"

// isStaging reports whether any address returned by net.InterfaceAddrs is a
// *net.IPNet whose IP prints as stagingIP. Other address types are ignored. A
// failure to list interfaces is treated as "not staging".
func isStaging() bool {
	addrs, err := net.InterfaceAddrs()
	if err != nil {
		return false
	}

	matched := false
	for _, addr := range addrs {
		ipNet, isIPNet := addr.(*net.IPNet)
		if !isIPNet {
			continue
		}
		if ipNet.IP.String() == stagingIP {
			matched = true
			break
		}
	}
	return matched
}
|
|
|
|
func serveHTTP() {
|
|
if !isStaging() {
|
|
fmt.Fprintln(os.Stderr, "dbquery -serve: refused (not staging)")
|
|
os.Exit(1)
|
|
}
|
|
|
|
if err := lib.CryptoInit(lib.KeyPathDefault); err != nil {
|
|
fmt.Fprintf(os.Stderr, "crypto init: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
db, err := sql.Open("sqlite3", dbPath)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "db open: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
http.HandleFunc("/query", func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != "POST" {
|
|
http.Error(w, "POST only", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
var req struct{ SQL string `json:"sql"` }
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.SQL == "" {
|
|
http.Error(w, `{"error":"sql required"}`, http.StatusBadRequest)
|
|
return
|
|
}
|
|
_, results, err := queryDB(db, req.SQL)
|
|
if err != nil {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
|
|
return
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(results)
|
|
})
|
|
|
|
http.HandleFunc("/exec", func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != "POST" {
|
|
http.Error(w, "POST only", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
var req struct{ SQL string `json:"sql"` }
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil || req.SQL == "" {
|
|
http.Error(w, `{"error":"sql required"}`, http.StatusBadRequest)
|
|
return
|
|
}
|
|
result, err := db.Exec(req.SQL)
|
|
if err != nil {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
|
|
return
|
|
}
|
|
affected, _ := result.RowsAffected()
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(map[string]int64{"affected": affected})
|
|
})
|
|
|
|
log.Printf("dbquery serving on :9124 (staging only)")
|
|
log.Fatal(http.ListenAndServe(":9124", nil))
|
|
}
|