clavitor/clavis/clavis-vault/lib/cvt.go

230 lines
6.4 KiB
Go

package lib
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"errors"
"fmt"
"io"
"math/big"
"strings"
"golang.org/x/crypto/hkdf"
)
// CVT record types.
//
// The vault server only ever handles wire tokens (0x00). Client credentials
// (0x01) carry L2 and are produced and consumed by the C CLI
// (clavis-cli/src/cvt.c); L2 is a hard veto for the server, so production Go
// code never decrypts them — MintCredential/ParseCredential exist only for
// tests.
const (
	// CVTWireToken tags a token sent to the vault: L1(8) + agent_id(16).
	CVTWireToken byte = 0x00
	// CVTCredentialType tags a client credential: L2(16) + agent_id(16) + POP(4).
	CVTCredentialType byte = 0x01
)

// cvtPrefix is the literal prefix that marks a string as a CVT token.
const cvtPrefix = "cvt_"

// Sentinel errors returned by the CVT parsers.
var (
	// ErrInvalidCVT reports a token whose outer framing is malformed
	// (missing prefix, bad base62, or truncated header).
	ErrInvalidCVT = errors.New("invalid cvt token")
	// ErrCVTDecrypt reports that the envelope could not be decrypted or
	// authenticated.
	ErrCVTDecrypt = errors.New("cvt decryption failed")
	// ErrCVTBadType reports a decryptable token of the wrong record type.
	ErrCVTBadType = errors.New("unexpected cvt record type")
)
// ---------------------------------------------------------------------------
// Minting
// ---------------------------------------------------------------------------
// MintWireToken builds a type 0x00 wire token.
//
// l0 must be 4 bytes (the envelope key is derived from it), l1 8 bytes and
// agentID 16 bytes. The 24-byte payload L1 || agent_id is encrypted with L0.
// Any other input lengths are rejected with a descriptive error.
func MintWireToken(l0, l1, agentID []byte) (string, error) {
	if ok := len(l0) == 4 && len(l1) == 8 && len(agentID) == 16; !ok {
		return "", fmt.Errorf("bad lengths: l0=%d l1=%d agent_id=%d", len(l0), len(l1), len(agentID))
	}
	payload := make([]byte, 0, 24)
	payload = append(payload, l1...)
	payload = append(payload, agentID...)
	return cvtEncode(CVTWireToken, l0, payload)
}
// ---------------------------------------------------------------------------
// Parsing
// ---------------------------------------------------------------------------
// ParseWireToken opens a type 0x00 wire token.
//
// On success it returns L0 (4 bytes), L1 (8 bytes) and agent_id (16 bytes);
// l1 and agentID are sub-slices of the decrypted payload. Envelope errors are
// passed through unchanged, a token of another record type yields
// ErrCVTBadType, and a payload that is not exactly 24 bytes yields a
// descriptive error.
func ParseWireToken(token string) (l0, l1, agentID []byte, err error) {
	var (
		typ     byte
		payload []byte
	)
	if typ, l0, payload, err = cvtDecode(token); err != nil {
		return nil, nil, nil, err
	}
	if typ != CVTWireToken {
		return nil, nil, nil, ErrCVTBadType
	}
	if n := len(payload); n != 24 {
		return nil, nil, nil, fmt.Errorf("wire payload: got %d bytes, want 24", n)
	}
	l1, agentID = payload[:8], payload[8:]
	return l0, l1, agentID, nil
}
// MintCredential builds a type 0x01 client credential token. It exists for
// tests: real credentials are generated client-side (browser/CLI).
//
// l0 must be 4 bytes, l2 16 bytes, agentID 16 bytes and pop 4 bytes. The
// 36-byte payload L2 || agent_id || POP is encrypted with L0. Any other input
// lengths are rejected with a descriptive error.
func MintCredential(l0, l2, agentID, pop []byte) (string, error) {
	if ok := len(l0) == 4 && len(l2) == 16 && len(agentID) == 16 && len(pop) == 4; !ok {
		return "", fmt.Errorf("bad lengths: l0=%d l2=%d agent_id=%d pop=%d", len(l0), len(l2), len(agentID), len(pop))
	}
	payload := make([]byte, 0, 36)
	for _, part := range [][]byte{l2, agentID, pop} {
		payload = append(payload, part...)
	}
	return cvtEncode(CVTCredentialType, l0, payload)
}
// ParseCredential opens a type 0x01 client credential token (test helper).
//
// On success it returns L0 (4 bytes), L2 (16 bytes), agent_id (16 bytes) and
// POP (4 bytes); all but l0 are sub-slices of the decrypted payload. Envelope
// errors are passed through unchanged, a token of another record type yields
// ErrCVTBadType, and a payload that is not exactly 36 bytes yields a
// descriptive error.
func ParseCredential(token string) (l0, l2, agentID, pop []byte, err error) {
	var (
		typ     byte
		payload []byte
	)
	if typ, l0, payload, err = cvtDecode(token); err != nil {
		return nil, nil, nil, nil, err
	}
	if typ != CVTCredentialType {
		return nil, nil, nil, nil, ErrCVTBadType
	}
	if n := len(payload); n != 36 {
		return nil, nil, nil, nil, fmt.Errorf("credential payload: got %d bytes, want 36", n)
	}
	l2, agentID, pop = payload[:16], payload[16:32], payload[32:]
	return l0, l2, agentID, pop, nil
}
// ---------------------------------------------------------------------------
// CVT envelope: type(1) + L0(4) + nonce + AES-128-GCM(key = HKDF-SHA256(L0), payload) ciphertext+tag
// ---------------------------------------------------------------------------
// cvtEncode seals payload into a CVT envelope and returns cvtPrefix + base62.
//
// Encoded byte layout: type(1) | L0(4) | nonce | GCM ciphertext+tag. The AES
// key is derived from l0 alone, and a fresh random nonce is drawn on every
// call. Errors come from cipher construction or the system random source.
//
// NOTE(review): typ and L0 are written in clear, and typ is not passed to GCM
// as additional data, so the type byte is not authenticated. Binding it would
// change the wire format, so it is only flagged here.
func cvtEncode(typ byte, l0, payload []byte) (string, error) {
	block, err := aes.NewCipher(cvtDeriveKey(l0))
	if err != nil {
		return "", err
	}
	aead, err := cipher.NewGCM(block)
	if err != nil {
		return "", err
	}

	const headerLen = 1 + 4
	nonceLen := aead.NonceSize()
	out := make([]byte, headerLen+nonceLen, headerLen+nonceLen+len(payload)+aead.Overhead())
	out[0] = typ
	copy(out[1:headerLen], l0)

	nonce := out[headerLen:]
	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
		return "", err
	}
	// Seal appends ciphertext+tag after the nonce, into the reserved capacity.
	out = aead.Seal(out, nonce, payload, nil)
	return cvtPrefix + base62Encode(out), nil
}
// cvtDecode parses and decrypts a CVT envelope produced by cvtEncode.
//
// The token must be cvtPrefix followed by base62 of
// type(1) | L0(4) | nonce | GCM ciphertext+tag. It returns the record type,
// the 4-byte L0 (a sub-slice of the decoded buffer) and the decrypted payload.
// Malformed framing (missing prefix, oversized or non-base62 body, fewer than
// 5 decoded bytes) yields ErrInvalidCVT. Any cryptographic failure, including
// a wrong L0 or a tampered ciphertext, yields ErrCVTDecrypt.
//
// NOTE(review): the type byte is not passed to GCM as additional data, so it
// is not authenticated. Callers rely on their type/length checks after
// decryption. Binding it would change the wire format.
func cvtDecode(token string) (typ byte, l0, payload []byte, err error) {
	// Tokens come from untrusted clients, and base62Decode is quadratic in the
	// input length (big.Int accumulation), so bound the size before decoding.
	// The largest current record (0x01: 1+4 header + 12 nonce + 36 payload +
	// 16 tag = 69 bytes) encodes to about 93 base62 digits. Raise this if a
	// larger record type is added.
	const maxEncodedLen = 256

	if !strings.HasPrefix(token, cvtPrefix) {
		return 0, nil, nil, ErrInvalidCVT
	}
	encoded := strings.TrimPrefix(token, cvtPrefix)
	if len(encoded) > maxEncodedLen {
		return 0, nil, nil, ErrInvalidCVT
	}
	// The base62 error detail is deliberately collapsed into ErrInvalidCVT.
	raw, err := base62Decode(encoded)
	if err != nil || len(raw) < 5 {
		return 0, nil, nil, ErrInvalidCVT
	}
	typ = raw[0]
	l0 = raw[1:5]
	ciphertext := raw[5:]

	block, err := aes.NewCipher(cvtDeriveKey(l0))
	if err != nil {
		return 0, nil, nil, ErrCVTDecrypt
	}
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return 0, nil, nil, ErrCVTDecrypt
	}
	ns := gcm.NonceSize()
	if len(ciphertext) < ns {
		return 0, nil, nil, ErrCVTDecrypt
	}
	payload, err = gcm.Open(nil, ciphertext[:ns], ciphertext[ns:], nil)
	if err != nil {
		return 0, nil, nil, ErrCVTDecrypt
	}
	return typ, l0, payload, nil
}
// cvtDeriveKey derives the 16-byte AES-128 envelope key from L0 using
// HKDF-SHA256 with no salt and the fixed info string "cvt-envelope".
//
// NOTE(review): the key depends on L0 alone, and L0 is stored in clear in
// every envelope. Anyone holding a token can therefore re-derive the key and
// decrypt it.
//
// It panics only if HKDF cannot produce 16 bytes. That cannot happen with
// SHA-256, whose HKDF output limit is 255*32 bytes. The panic replaces the old
// silent fallback, which would have returned an all-zero key.
func cvtDeriveKey(l0 []byte) []byte {
	key := make([]byte, 16)
	kdf := hkdf.New(sha256.New, l0, nil, []byte("cvt-envelope"))
	if _, err := io.ReadFull(kdf, key); err != nil {
		panic("cvt: hkdf key derivation failed: " + err.Error())
	}
	return key
}
// ---------------------------------------------------------------------------
// Base62
// ---------------------------------------------------------------------------
// base62Chars is the base62 digit alphabet: digit value i is base62Chars[i].
const base62Chars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"

// base62Encode renders data as a big-endian base62 number. big.Int cannot
// represent leading zero bytes, so each one is emitted as an explicit '0'
// digit ahead of the number's digits. Empty input yields "".
func base62Encode(data []byte) string {
	zeros := 0
	for zeros < len(data) && data[zeros] == 0 {
		zeros++
	}

	num := new(big.Int).SetBytes(data)
	radix := big.NewInt(62)
	rem := new(big.Int)
	var digits []byte // least-significant digit first
	for num.Sign() > 0 {
		num.QuoRem(num, radix, rem)
		digits = append(digits, base62Chars[rem.Int64()])
	}

	var sb strings.Builder
	sb.Grow(zeros + len(digits))
	for i := 0; i < zeros; i++ {
		sb.WriteByte(base62Chars[0])
	}
	for i := len(digits) - 1; i >= 0; i-- {
		sb.WriteByte(digits[i])
	}
	return sb.String()
}

// base62Decode is the inverse of base62Encode. Every leading '0' digit
// becomes one leading zero byte. Any rune outside base62Chars produces an
// error. The result is never nil on success; "" decodes to an empty slice.
func base62Decode(s string) ([]byte, error) {
	num := new(big.Int)
	radix := big.NewInt(62)
	digit := new(big.Int)
	for _, c := range s {
		idx := strings.IndexRune(base62Chars, c)
		if idx < 0 {
			return nil, fmt.Errorf("invalid base62 character: %c", c)
		}
		num.Mul(num, radix).Add(num, digit.SetInt64(int64(idx)))
	}

	zeros := len(s) - len(strings.TrimLeft(s, base62Chars[:1]))
	mag := num.Bytes()
	out := make([]byte, zeros+len(mag))
	copy(out[zeros:], mag)
	return out, nil
}