diff --git a/cmd/mithril-go/json.go b/cmd/mithril-go/json.go index b647a47..85a3315 100644 --- a/cmd/mithril-go/json.go +++ b/cmd/mithril-go/json.go @@ -3,7 +3,6 @@ package main import ( "encoding/json" "fmt" - "io" "os" ) @@ -19,12 +18,39 @@ func emitJSON(v any) int { return 0 } -// emitJSONErr writes a structured error envelope. Mirrors the shape -// Claude/MCP-friendly consumers want: {"error": {"code":..., "message":...}}. -func emitJSONErr(w io.Writer, code, msg string) { - enc := json.NewEncoder(w) +// emitJSONErr writes a structured error envelope to stdout in the shape +// MCP / agent consumers expect: +// +// {"error": {"code": "...", "message": "..."}} +// +// Returns the supplied exit code so callers can do `return emitJSONErr(...)`. +func emitJSONErr(code int, kind, msg string) int { + enc := json.NewEncoder(os.Stdout) enc.SetIndent("", " ") _ = enc.Encode(map[string]any{ - "error": map[string]string{"code": code, "message": msg}, + "error": map[string]any{ + "code": code, + "kind": kind, + "message": msg, + }, }) + return code +} + +// failure routes an error to either stdout-as-JSON (when the user passed +// -json) or stderr-as-text (default). Returns the supplied exit code. +// +// kind is a stable short string ("network", "integrity", "verify", +// "usage", "internal") — agents can branch on this without parsing +// human-readable text. +func failure(asJSON bool, code int, kind, prefix string, err error) int { + msg := err.Error() + if prefix != "" { + msg = prefix + ": " + msg + } + if asJSON { + return emitJSONErr(code, kind, msg) + } + fmt.Fprintln(os.Stderr, msg) + return code } diff --git a/cmd/mithril-go/main.go b/cmd/mithril-go/main.go index b253c61..635c755 100644 --- a/cmd/mithril-go/main.go +++ b/cmd/mithril-go/main.go @@ -16,6 +16,7 @@ import ( "encoding/json" "flag" "fmt" + "io" "net/http" "os" "os/signal" @@ -34,7 +35,7 @@ import ( "git.sulkta.coop/Sulkta-Coop/mithril-go/internal/verify" ) -const version = "0.0.3-dev" +const version = "1.0.1" // Stable exit codes. Any addition goes at the end; existing values // don't renumber. LLM/automation-friendly contract. @@ -132,14 +133,12 @@ func cmdList(ctx context.Context, args []string) int { asJSON := fs.Bool("json", false, "emit structured JSON") n, _, err := resolveNetwork(fs, args) if err != nil { - fmt.Fprintln(os.Stderr, err) - return 2 + return failure(*asJSON, exitUsage, "usage", "", err) } c := aggregator.New(n.AggregatorURL) snaps, err := c.ListCardanoDBSnapshots(ctx) if err != nil { - fmt.Fprintln(os.Stderr, "list:", err) - return exitNetwork + return failure(*asJSON, exitNetwork, "network", "list", err) } if *asJSON { return emitJSON(map[string]any{ @@ -167,8 +166,7 @@ func cmdShow(ctx context.Context, args []string) int { asJSON := fs.Bool("json", false, "emit structured JSON") n, rest, err := resolveNetwork(fs, args) if err != nil { - fmt.Fprintln(os.Stderr, err) - return exitUsage + return failure(*asJSON, exitUsage, "usage", "", err) } hash := "latest" if len(rest) > 0 { @@ -177,8 +175,7 @@ func cmdShow(ctx context.Context, args []string) int { c := aggregator.New(n.AggregatorURL) snap, err := resolveSnapshot(ctx, c, hash) if err != nil { - fmt.Fprintln(os.Stderr, "show:", err) - return exitNetwork + return failure(*asJSON, exitNetwork, "network", "show", err) } if *asJSON { return emitJSON(snap) @@ -282,12 +279,11 @@ func cmdVerify(ctx context.Context, args []string) int { asJSON := fs.Bool("json", false, "emit structured JSON") n, rest, err := resolveNetwork(fs, args) if err != nil { - fmt.Fprintln(os.Stderr, err) - return exitUsage + return failure(*asJSON, exitUsage, "usage", "", err) } if len(rest) == 0 { - fmt.Fprintln(os.Stderr, "verify: cert hash required (or 'head' / 'genesis')") - return exitUsage + return failure(*asJSON, exitUsage, "usage", "", + fmt.Errorf("verify: cert hash required (or 'head' / 'genesis' / 'chain' / 'manifest ')")) } mode := rest[0] // "head" = verify head cert (STM, not yet), "genesis" = walk chain + verify genesis, or a specific hash c := aggregator.New(n.AggregatorURL) @@ -309,24 +305,21 @@ func cmdVerify(ctx context.Context, args []string) int { func runVerifyManifest(args []string, asJSON bool) int { if len(args) == 0 { - fmt.Fprintln(os.Stderr, "verify manifest: needs path to download dir (with digests/ + db/)") - return exitUsage + return failure(asJSON, exitUsage, "usage", "", + fmt.Errorf("verify manifest: needs path to download dir (with digests/ + db/)")) } dir := args[0] digestsPath, err := manifest.LocateDigests(filepath.Join(dir, "digests")) if err != nil { - fmt.Fprintln(os.Stderr, "locate digests.json:", err) - return exitGeneric + return failure(asJSON, exitGeneric, "internal", "locate digests.json", err) } entries, err := manifest.Load(digestsPath) if err != nil { - fmt.Fprintln(os.Stderr, "load manifest:", err) - return exitIntegrity + return failure(asJSON, exitIntegrity, "integrity", "load manifest", err) } res, err := manifest.Verify(entries, filepath.Join(dir, "db")) if err != nil { - fmt.Fprintln(os.Stderr, "verify manifest:", err) - return exitGeneric + return failure(asJSON, exitGeneric, "internal", "verify manifest", err) } if asJSON { code := emitJSON(res) @@ -349,13 +342,11 @@ func runVerifyChain(ctx context.Context, n networks.Network, asJSON bool) int { c := aggregator.New(n.AggregatorURL) snap, err := resolveSnapshot(ctx, c, "latest") if err != nil { - fmt.Fprintln(os.Stderr, "resolve:", err) - return exitNetwork + return failure(asJSON, exitNetwork, "network", "resolve", err) } res, err := chain.Verify(ctx, nil, n, snap.CertificateHash, 2048) if err != nil { - fmt.Fprintln(os.Stderr, "chain verify:", err) - return exitNetwork + return failure(asJSON, exitNetwork, "network", "chain verify", err) } if asJSON { code := emitJSON(res) @@ -390,22 +381,19 @@ func runVerifyGenesis(ctx context.Context, c *aggregator.Client, n networks.Netw // Find the head snapshot's cert, walk to genesis, verify Ed25519 on the genesis cert. snap, err := resolveSnapshot(ctx, c, "latest") if err != nil { - fmt.Fprintln(os.Stderr, "resolve:", err) - return exitNetwork + return failure(asJSON, exitNetwork, "network", "resolve", err) } - chain, err := c.CertChain(ctx, snap.CertificateHash, 2048) + certs, err := c.CertChain(ctx, snap.CertificateHash, 2048) if err != nil { - fmt.Fprintln(os.Stderr, "chain:", err) - return exitNetwork + return failure(asJSON, exitNetwork, "network", "chain", err) } - if len(chain) == 0 { - fmt.Fprintln(os.Stderr, "empty chain") - return exitGeneric + if len(certs) == 0 { + return failure(asJSON, exitGeneric, "internal", "", fmt.Errorf("empty chain")) } - gen := chain[len(chain)-1] + gen := certs[len(certs)-1] if gen.GenesisSignature == "" { - fmt.Fprintln(os.Stderr, "tail of chain is not a genesis certificate") - return exitGeneric + return failure(asJSON, exitGeneric, "internal", "", + fmt.Errorf("tail of chain is not a genesis certificate")) } return verifyGenesisCert(n, gen, asJSON) } @@ -413,8 +401,7 @@ func runVerifyGenesis(ctx context.Context, c *aggregator.Client, n networks.Netw func runVerifyHead(ctx context.Context, c *aggregator.Client, n networks.Network, asJSON bool) int { snap, err := resolveSnapshot(ctx, c, "latest") if err != nil { - fmt.Fprintln(os.Stderr, "resolve:", err) - return exitNetwork + return failure(asJSON, exitNetwork, "network", "resolve", err) } return runVerifySingle(ctx, c, n, snap.CertificateHash, asJSON) } @@ -422,8 +409,7 @@ func runVerifyHead(ctx context.Context, c *aggregator.Client, n networks.Network func runVerifySingle(ctx context.Context, c *aggregator.Client, n networks.Network, hash string, asJSON bool) int { cert, err := c.GetCertificate(ctx, hash) if err != nil { - fmt.Fprintln(os.Stderr, "cert:", err) - return exitNetwork + return failure(asJSON, exitNetwork, "network", "cert", err) } if cert.GenesisSignature != "" { return verifyGenesisCert(n, cert, asJSON) @@ -438,18 +424,15 @@ func verifySTMCert(ctx context.Context, c *aggregator.Client, n networks.Network // Re-fetch as raw JSON to access the AVK + params fields. raw, err := fetchCertRaw(ctx, n.AggregatorURL, hash) if err != nil { - fmt.Fprintln(os.Stderr, "fetch raw cert:", err) - return exitNetwork + return failure(asJSON, exitNetwork, "network", "fetch raw cert", err) } ms, err := stm.DecodeMultiSig(raw.MultiSignature) if err != nil { - fmt.Fprintln(os.Stderr, "decode multi_signature:", err) - return exitIntegrity + return failure(asJSON, exitIntegrity, "integrity", "decode multi_signature", err) } avk, err := stm.DecodeAVK(raw.AggregateVerificationKey) if err != nil { - fmt.Fprintln(os.Stderr, "decode avk:", err) - return exitIntegrity + return failure(asJSON, exitIntegrity, "integrity", "decode avk", err) } msg := []byte(cert.SignedMessage) params := stm.Parameters{K: raw.Metadata.Parameters.K, M: raw.Metadata.Parameters.M, PhiF: raw.Metadata.Parameters.PhiF} @@ -501,14 +484,21 @@ func fetchCertRaw(ctx context.Context, aggregatorURL, hash string) (*rawCert, er if err != nil { return nil, err } + req.Header.Set("Accept", "application/json") resp, err := http.DefaultClient.Do(req) if err != nil { return nil, err } defer resp.Body.Close() + if resp.StatusCode >= 400 { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048)) + return nil, fmt.Errorf("aggregator GET /certificate/%s: %d: %s", hash, resp.StatusCode, string(body)) + } + // Cap at 16 MiB — current mainnet cert JSON is well under 100 KiB. + limited := io.LimitReader(resp.Body, 16<<20) var r rawCert - if err := json.NewDecoder(resp.Body).Decode(&r); err != nil { - return nil, err + if err := json.NewDecoder(limited).Decode(&r); err != nil { + return nil, fmt.Errorf("decode cert json: %w", err) } return &r, nil } @@ -516,8 +506,7 @@ func fetchCertRaw(ctx context.Context, aggregatorURL, hash string) (*rawCert, er func verifyGenesisCert(n networks.Network, cert *aggregator.Certificate, asJSON bool) int { vk, err := verify.DecodeGenesisVerifyKey(n.GenesisVerifyKey) if err != nil { - fmt.Fprintln(os.Stderr, "decode genesis key:", err) - return exitGeneric + return failure(asJSON, exitGeneric, "internal", "decode genesis key", err) } err = verify.GenesisFromJSON(vk, cert.SignedMessage, cert.GenesisSignature, cert.ProtocolMessage) if asJSON { @@ -552,28 +541,25 @@ func cmdCert(ctx context.Context, args []string) int { asJSON := fs.Bool("json", false, "emit structured JSON") n, rest, err := resolveNetwork(fs, args) if err != nil { - fmt.Fprintln(os.Stderr, err) - return exitUsage + return failure(*asJSON, exitUsage, "usage", "", err) } if len(rest) == 0 { - fmt.Fprintln(os.Stderr, "cert: hash required (or 'head' to use the latest snapshot's cert_hash)") - return exitUsage + return failure(*asJSON, exitUsage, "usage", "", + fmt.Errorf("cert: hash required (or 'head' to use the latest snapshot's cert_hash)")) } head := rest[0] c := aggregator.New(n.AggregatorURL) if head == "head" { snap, err := resolveSnapshot(ctx, c, "latest") if err != nil { - fmt.Fprintln(os.Stderr, "resolve head:", err) - return exitNetwork + return failure(*asJSON, exitNetwork, "network", "resolve head", err) } head = snap.CertificateHash } if *chain { certs, err := c.CertChain(ctx, head, *maxDepth) if err != nil { - fmt.Fprintln(os.Stderr, "chain:", err) - return exitNetwork + return failure(*asJSON, exitNetwork, "network", "chain", err) } if *asJSON { return emitJSON(map[string]any{"chain_length": len(certs), "certs": certs}) @@ -591,8 +577,7 @@ func cmdCert(ctx context.Context, args []string) int { } cert, err := c.GetCertificate(ctx, head) if err != nil { - fmt.Fprintln(os.Stderr, "cert:", err) - return exitNetwork + return failure(*asJSON, exitNetwork, "network", "cert", err) } if *asJSON { return emitJSON(cert) @@ -619,8 +604,7 @@ func cmdInfo(args []string) int { asJSON := fs.Bool("json", false, "emit structured JSON") n, _, err := resolveNetwork(fs, args) if err != nil { - fmt.Fprintln(os.Stderr, err) - return exitUsage + return failure(*asJSON, exitUsage, "usage", "", err) } if *asJSON { return emitJSON(map[string]any{ diff --git a/internal/artifact/download.go b/internal/artifact/download.go index e29546a..575d94c 100644 --- a/internal/artifact/download.go +++ b/internal/artifact/download.go @@ -121,6 +121,10 @@ func Download(ctx context.Context, uri, destPath, expectedSHA256 string, progres if expectedSHA256 != "" { got := hex.EncodeToString(h.Sum(nil)) if got != expectedSHA256 { + // Remove the .part file — leaving it behind would cause every + // subsequent retry to resume from the same corrupted bytes and + // fail SHA again indefinitely. + _ = os.Remove(partPath) return fmt.Errorf("SHA256 mismatch: want %s, got %s", expectedSHA256, got) } } diff --git a/internal/stm/lottery.go b/internal/stm/lottery.go index 5c23ebc..6d6bd15 100644 --- a/internal/stm/lottery.go +++ b/internal/stm/lottery.go @@ -58,6 +58,12 @@ func IsLotteryWon(phiF float64, ev [64]byte, stake, totalStake uint64) bool { if math.Abs(phiF-1.0) < 1e-15 { return true } + // Defensive: zero-stake or zero-total-stake produces nonsense (and + // totalStake==0 would panic at SetFrac). Guard at the lottery layer + // in addition to AVK-decode-time validation. + if stake == 0 || totalStake == 0 { + return false + } // ev as big int (LE interpretation) evInt := evAsBigInt(ev) diff --git a/internal/stm/merkle.go b/internal/stm/merkle.go index 51c4211..9984997 100644 --- a/internal/stm/merkle.go +++ b/internal/stm/merkle.go @@ -4,7 +4,6 @@ import ( "bytes" "encoding/binary" "fmt" - "sort" "golang.org/x/crypto/blake2b" ) @@ -78,16 +77,31 @@ func VerifyMerkleBatch(root []byte, nrLeaves int, leafValues [][]byte, indices [ if len(leafValues) != len(indices) { return fmt.Errorf("leaves/indices count mismatch: %d vs %d", len(leafValues), len(indices)) } - // Must be sorted ascending + if nrLeaves <= 0 { + return fmt.Errorf("nrLeaves must be positive, got %d", nrLeaves) + } + // Validate every proof node is a 32-byte BLAKE2b-256 digest. Anything + // shorter or longer is malformed and Rust would reject it. + for i, v := range proofValues { + if len(v) != 32 { + return fmt.Errorf("proof value [%d]: got %d bytes, want 32", i, len(v)) + } + } + // Indices must be strictly ascending — duplicates would create + // double-claiming under the same leaf and the algorithm doesn't expect + // them. (Rust uses sort_unstable + equality compare against the input; + // equivalent to "non-decreasing" but doesn't reject equal-adjacent. + // We're stricter than upstream here on purpose.) ordered := make([]int, len(indices)) for i, v := range indices { + if v >= uint64(nrLeaves) { + return fmt.Errorf("index [%d]=%d out of range (nr_leaves=%d)", i, v, nrLeaves) + } ordered[i] = int(v) } - sortedCopy := append([]int(nil), ordered...) - sort.Ints(sortedCopy) - for i := range ordered { - if ordered[i] != sortedCopy[i] { - return fmt.Errorf("indices not sorted ascending: %v", indices) + for i := 1; i < len(ordered); i++ { + if ordered[i] <= ordered[i-1] { + return fmt.Errorf("indices not strictly ascending at [%d]: %v", i, indices) } } @@ -155,6 +169,12 @@ func VerifyMerkleBatch(root []byte, nrLeaves int, leafValues [][]byte, indices [ if len(currentLayer) != 1 { return fmt.Errorf("verification ended with %d nodes, want 1", len(currentLayer)) } + // All proof values must be consumed. Trailing bytes mean the proof + // shipped extra nodes the algorithm didn't need — likely malformed + // or attacker-padded. + if len(values) > 0 { + return fmt.Errorf("proof has %d unconsumed values — malformed", len(values)) + } if !bytes.Equal(currentLayer[0], root) { return fmt.Errorf("root mismatch: got %x, want %x", currentLayer[0], root) } diff --git a/internal/stm/types.go b/internal/stm/types.go index dd86488..6873067 100644 --- a/internal/stm/types.go +++ b/internal/stm/types.go @@ -116,6 +116,12 @@ func DecodeAVK(rawJSON []byte) (*AVK, error) { if len(wire.MTCommitment.Root) != 32 { return nil, fmt.Errorf("AVK root: got %d bytes, want 32", len(wire.MTCommitment.Root)) } + if wire.TotalStake == 0 { + return nil, fmt.Errorf("AVK total_stake is zero") + } + if wire.MTCommitment.NrLeaves == 0 { + return nil, fmt.Errorf("AVK nr_leaves is zero") + } return &AVK{ MerkleRoot: wire.MTCommitment.Root, NumLeaves: wire.MTCommitment.NrLeaves, diff --git a/internal/verify/verify.go b/internal/verify/verify.go index ab02fce..3fce58e 100644 --- a/internal/verify/verify.go +++ b/internal/verify/verify.go @@ -27,7 +27,6 @@ var ( ErrNotGenesis = errors.New("certificate is not a genesis certificate") ErrBadSignature = errors.New("genesis signature verification failed") ErrSignedMessageHash = errors.New("signed_message does not match SHA256(protocol_message)") - ErrSTMNotImplemented = errors.New("STM signature verification not implemented yet") ) // The Mithril enum order on ProtocolMessagePartKey — BTreeMap iteration @@ -182,9 +181,5 @@ func GenesisFromJSON(verifyKey ed25519.PublicKey, signedMessageHex, genesisSigna return Genesis(verifyKey, signedMessageHex, genesisSignatureHex, pm) } -// STM verifies a non-genesis certificate's aggregate BLS signature. -// Stub — target is Mithril STM paper §5 (signing) + §6 (aggregation) -// using gnark-crypto's bls12-381 primitives. -func STM(protocolMessageJSON, multiSignature []byte, avk any) error { - return ErrSTMNotImplemented -} +// STM verification lives in the sibling internal/stm package — see +// stm.Verify(). This file is genesis-Ed25519-only.