mithril-go/internal/stm/merkle.go
Kayos 9d6c7cffbe v1.0.1: audit fixes — fetchCertRaw status check, .part cleanup, AVK guards, strict merkle, JSON error envelope
Independent code audit (in-repo, fresh-eyes pass) flagged 0 critical, 4
high, 8 medium, 7 low. This commit addresses all 4 highs + the JSON
error-path inconsistency + the vestigial verify.STM stub.

HIGH fixes:
- cmd/mithril-go/main.go fetchCertRaw: missing status check let HTML 4xx/5xx
  bodies fall through to confusing JSON-decode errors. Added explicit
  StatusCode>=400 check + 16 MiB response body cap + Accept header.
- internal/artifact/download.go: SHA mismatch left .part on disk, causing
  every retry to resume the corrupted bytes and fail SHA forever. Now
  removes .part on hash mismatch so the next attempt starts clean.
- internal/stm/types.go DecodeAVK: rejects total_stake=0 and nr_leaves=0
  at decode-time. internal/stm/lottery.go adds defensive guard for
  stake==0 || totalStake==0 to prevent big.Rat.SetFrac panic (DoS vector
  for the MCP server when fed crafted AVK).
- internal/stm/merkle.go: now requires (a) every proof value is exactly
  32 bytes, (b) indices are STRICTLY ascending (no duplicates),
  (c) every index is < nr_leaves, (d) all proof values are consumed by
  the algorithm. Prevents parser-differential bugs vs upstream Rust.

JSON error-path wiring:
- cmd/mithril-go/json.go: replaced unused emitJSONErr with failure() helper
  that routes errors to stdout-as-JSON when -json is set, else stderr-as-text.
  Error envelope shape: {error: {code, kind, message}} where 'kind' is a
  stable short string (network/integrity/verify/usage/internal) for agents
  to branch on without parsing human text.
- All -json-supporting commands (info, list, show, cert, verify+subcommands)
  now use failure() in error paths instead of bare fmt.Fprintln(stderr).
- Verified: 'verify -json deadbeef' on a bogus hash now emits valid JSON
  to stdout with exit=3, instead of empty stdout + text on stderr.

Vestigial code:
- internal/verify/verify.go: removed STM() stub + ErrSTMNotImplemented.
  Real STM verification has lived in internal/stm/verify.go since the
  crypto sprint; the stub was dead code from milestone-by-milestone work.

Verification (still all green):
- preprod chain: 90 certs, 1124 wins ✓
- mainnet head:  59 signers, 1972 wins ✓
- preprod head:   2 signers,   11 wins ✓
- preprod genesis: Ed25519 ✓
- JSON error envelope on bogus hash: well-formed JSON, exit=3
- internal/stm unit test: PASS

Audit findings deferred to v1.0.2+: bubble-sort in stm.Verify (medium,
perf only at scale); int-vs-uint64 truncation guards on 32-bit targets
(medium, won't bite on 64-bit); tar mode-bit masking (medium, low impact
since archives are from trusted aggregator); no User-Agent header on
aggregator requests (low, op nicety); MCP scanner silent stop on >10 MiB
line (low, defensive).
2026-04-23 17:30:34 -07:00

182 lines
5.7 KiB
Go

package stm

import (
	"bytes"
	"encoding/binary"
	"fmt"

	"golang.org/x/crypto/blake2b"
)

// Mithril's Merkle tree uses Blake2b-256 over leaf encodings:
//
//	leaf_bytes = vk_96 || stake_be_u64   (104 bytes)
//	leaf_hash  = Blake2b-256(leaf_bytes)
//	internal   = Blake2b-256(left_hash || right_hash)
//	empty_sib  = Blake2b-256(0x00)
//
// The tree is heap-indexed: the root is at 0 and the children of node i are
// at 2i+1 (left) and 2i+2 (right). Leaves occupy indices
// next_power_of_two(nr_leaves)-1 through
// next_power_of_two(nr_leaves)-1 + nr_leaves-1.

// blake2b256 returns the Blake2b-256 digest of the concatenation of data.
func blake2b256(data ...[]byte) []byte {
	// New256 fails only for keys longer than 64 bytes, so a nil key is safe.
	h, _ := blake2b.New256(nil)
	for _, d := range data {
		h.Write(d)
	}
	return h.Sum(nil)
}
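
// LeafBytes encodes a (vk, stake) pair as the 104-byte leaf value hashed
// into the Merkle tree: the 96-byte verification key followed by the stake
// as a big-endian uint64. vk is expected to be exactly 96 bytes; copy
// truncates a longer slice and leaves a shorter one zero-padded.
func LeafBytes(vk []byte, stake uint64) []byte {
	out := make([]byte, 104)
	copy(out[:96], vk)
	binary.BigEndian.PutUint64(out[96:], stake)
	return out
}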
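
// nextPowerOfTwo returns the smallest power of two >= n; for n <= 1 it
// returns 1. Callers must bound n: above 2^62 on 64-bit platforms (2^30 on
// 32-bit) the shift overflows and the loop never terminates.
func nextPowerOfTwo(n int) int {
	if n <= 1 {
		return 1
	}
	p := 1
	for p < n {
		p <<= 1
	}
	return p
}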
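
// mtParent returns the heap index of node i's parent.
func mtParent(i int) int { return (i - 1) / 2 }

// mtSibling returns the heap index of node i's sibling: odd indices are left
// children (sibling i+1) and even indices are right children (sibling i-1).
func mtSibling(i int) int {
	if i%2 == 1 {
		return i + 1
	}
	return i - 1
}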

// VerifyMerkleBatch verifies a Mithril batch proof: it checks that the given
// leaf values sit at the given indices of the tree with the given root. It
// returns nil on success.
//
// Arguments:
//   - root: 32-byte Merkle root
//   - nrLeaves: total number of leaves in the tree (from the AVK commitment)
//   - leafValues: for each proved leaf, its pre-hash bytes (vk||stake)
//   - indices: the leaf indices (0-based, within the leaf range); must be
//     non-empty, strictly ascending, and the same length as leafValues
//   - proofValues: the Merkle path nodes as provided in the batch proof's
//     `values` field
//
// The algorithm walks layer by layer from the leaves to the root. At each
// layer every claimed node is hashed with its sibling, which is another
// claimed node when both children are claimed, the empty hash when the
// sibling lies past the end of the tree, and otherwise the next unconsumed
// proof value. Direct port of
// mithril-stm::membership_commitment::merkle_tree::commitment::verify_leaves_membership_from_batch_path.
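func VerifyMerkleBatch(root []byte, nrLeaves int, leafValues [][]byte, indices []uint64, proofValues [][]byte) error {
	if len(leafValues) != len(indices) {
		return fmt.Errorf("leaves/indices count mismatch: %d vs %d", len(leafValues), len(indices))
	}
	// An empty batch would panic on ordered[0] below.
	if len(indices) == 0 {
		return fmt.Errorf("batch proof claims no leaves")
	}
	if nrLeaves <= 0 {
		return fmt.Errorf("nrLeaves must be positive, got %d", nrLeaves)
	}
	// Cap nrLeaves so nextPowerOfTwo terminates and npo2 (< 2*nrLeaves) and
	// nrNodes (< 3*nrLeaves) cannot overflow int. Without the cap, an oversized
	// value (for example from a crafted AVK) would hang the verifier.
	const maxInt = int(^uint(0) >> 1)
	if nrLeaves > maxInt/4 {
		return fmt.Errorf("nrLeaves %d too large", nrLeaves)
	}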

	// Every proof node must be a 32-byte BLAKE2b-256 digest. Anything
	// shorter or longer is malformed, and Rust would reject it.
	for i, v := range proofValues {
		if len(v) != 32 {
			return fmt.Errorf("proof value [%d]: got %d bytes, want 32", i, len(v))
		}
	}

	// Indices must be in range and strictly ascending. A duplicate would
	// claim the same leaf twice, which the layer walk below does not expect.
	// Upstream Rust sorts a copy with sort_unstable and compares it with the
	// input. That enforces non-decreasing order but accepts equal adjacent
	// indices, so rejecting duplicates here is deliberately stricter.
	ordered := make([]int, len(indices))
	for i, v := range indices {
		if v >= uint64(nrLeaves) {
			return fmt.Errorf("index [%d]=%d out of range (nr_leaves=%d)", i, v, nrLeaves)
		}
		ordered[i] = int(v)
	}
	for i := 1; i < len(ordered); i++ {
		if ordered[i] <= ordered[i-1] {
			return fmt.Errorf("indices not strictly ascending at [%d]: %v", i, indices)
		}
	}

	npo2 := nextPowerOfTwo(nrLeaves)
	nrNodes := nrLeaves + npo2 - 1

	// Shift leaf positions into heap coordinates.
	for i := range ordered {
		ordered[i] += npo2 - 1
	}

	// Hash each leaf.
	currentLayer := make([][]byte, len(leafValues))
	for i, lv := range leafValues {
		currentLayer[i] = blake2b256(lv)
	}

	// values is consumed front to back as proof siblings are needed. idx
	// tracks the first claimed node's ancestor, so the walk stops at the root.
	values := append([][]byte(nil), proofValues...)
	idx := ordered[0]
	emptySiblingHash := blake2b256([]byte{0x00})
	for idx > 0 {
		newHashes := make([][]byte, 0, len(ordered))
		newIndices := make([]int, 0, len(ordered))
		i := 0
		idx = mtParent(idx)
		for i < len(ordered) {
			newIndices = append(newIndices, mtParent(ordered[i]))
			if ordered[i]&1 == 0 {
				// Current is a RIGHT child: its sibling (LEFT) comes from the proof values.
				if len(values) == 0 {
					return fmt.Errorf("proof truncated at ordered[%d]=%d (expected left sibling)", i, ordered[i])
				}
				sib := values[0]
				values = values[1:]
				newHashes = append(newHashes, blake2b256(sib, currentLayer[i]))
			} else {
				// Current is a LEFT child: its sibling is on the RIGHT.
				sib := mtSibling(ordered[i])
				switch {
				case i+1 < len(ordered) && ordered[i+1] == sib:
					// Sibling is ALSO a claimed node already in currentLayer.
					newHashes = append(newHashes, blake2b256(currentLayer[i], currentLayer[i+1]))
					i++
				case sib < nrNodes:
					// Sibling is not claimed but exists; take it from the proof.
					if len(values) == 0 {
						return fmt.Errorf("proof truncated at ordered[%d]=%d (expected right sibling)", i, ordered[i])
					}
					s := values[0]
					values = values[1:]
					newHashes = append(newHashes, blake2b256(currentLayer[i], s))
				default:
					// Right sibling lies past the end of the tree: use the empty sibling hash.
					newHashes = append(newHashes, blake2b256(currentLayer[i], emptySiblingHash))
				}
			}
			i++
		}
		currentLayer = newHashes
		ordered = newIndices
	}

	if len(currentLayer) != 1 {
		return fmt.Errorf("verification ended with %d nodes, want 1", len(currentLayer))
	}
	// All proof values must be consumed. Leftover values mean the proof
	// shipped nodes the algorithm never needed, so it is malformed or
	// attacker-padded.
	if len(values) > 0 {
		return fmt.Errorf("proof has %d unconsumed values — malformed", len(values))
	}
	if !bytes.Equal(currentLayer[0], root) {
		return fmt.Errorf("root mismatch: got %x, want %x", currentLayer[0], root)
	}
	return nil
}