diff --git a/cmd/mithril-go/main.go b/cmd/mithril-go/main.go index fb570c3..2872d42 100644 --- a/cmd/mithril-go/main.go +++ b/cmd/mithril-go/main.go @@ -48,6 +48,8 @@ func main() { os.Exit(cmdDownload(ctx, args)) case "verify": os.Exit(cmdVerify(ctx, args)) + case "cert": + os.Exit(cmdCert(ctx, args)) case "info": os.Exit(cmdInfo(args)) case "version", "--version", "-v": @@ -70,6 +72,7 @@ Usage: Commands: list List available cardano-database snapshots show Show detail for one snapshot (hash or "latest") + cert Show a certificate or walk the chain back to genesis download Download a snapshot to a target directory verify Verify an already-downloaded snapshot (not yet implemented) info Show network + aggregator info @@ -190,7 +193,7 @@ func cmdDownload(ctx context.Context, args []string) int { return 1 } digestsArchive := filepath.Join(*out, "digests.tar.zst") - if err := downloadWithBar(ctx, digestsURIs[0], digestsArchive, snap.Digests.SizeUncompressed); err != nil { + if err := downloadWithBar(ctx, digestsURIs[0], digestsArchive); err != nil { fmt.Fprintln(os.Stderr, "digests download:", err) return 1 } @@ -209,7 +212,7 @@ func cmdDownload(ctx context.Context, args []string) int { return 1 } anciArchive := filepath.Join(*out, "ancillary.tar.zst") - if err := downloadWithBar(ctx, anciURIs[0], anciArchive, snap.Ancillary.SizeUncompressed); err != nil { + if err := downloadWithBar(ctx, anciURIs[0], anciArchive); err != nil { fmt.Fprintln(os.Stderr, "ancillary download:", err) return 1 } @@ -235,6 +238,68 @@ func cmdVerify(ctx context.Context, args []string) int { return 1 } +func cmdCert(ctx context.Context, args []string) int { + fs := flag.NewFlagSet("cert", flag.ExitOnError) + chain := fs.Bool("chain", false, "walk previous_hash back to the genesis certificate") + maxDepth := fs.Int("max-depth", 1024, "chain walk safety cap") + n, rest, err := resolveNetwork(fs, args) + if err != nil { + fmt.Fprintln(os.Stderr, err) + return 2 + } + if len(rest) == 0 { + fmt.Fprintln(os.Stderr, "cert: hash required (or 'head' to use the latest snapshot's cert_hash)") + return 2 + } + head := rest[0] + c := aggregator.New(n.AggregatorURL) + if head == "head" { + snap, err := resolveSnapshot(ctx, c, "latest") + if err != nil { + fmt.Fprintln(os.Stderr, "resolve head:", err) + return 1 + } + head = snap.CertificateHash + } + if *chain { + certs, err := c.CertChain(ctx, head, *maxDepth) + if err != nil { + fmt.Fprintln(os.Stderr, "chain:", err) + return 1 + } + fmt.Printf("chain length: %d (head → genesis)\n\n", len(certs)) + for i, ct := range certs { + role := "" + if ct.GenesisSignature != "" { + role = " [GENESIS]" + } + fmt.Printf("[%3d] %s epoch=%d prev=%.16s…%s\n", + i, ct.Hash, ct.Epoch, ct.PreviousHash, role) + } + return 0 + } + cert, err := c.GetCertificate(ctx, head) + if err != nil { + fmt.Fprintln(os.Stderr, "cert:", err) + return 1 + } + fmt.Printf("hash: %s\n", cert.Hash) + fmt.Printf("previous_hash: %s\n", cert.PreviousHash) + fmt.Printf("epoch: %d\n", cert.Epoch) + fmt.Printf("signed_message: %s\n", cert.SignedMessage) + fmt.Printf("genesis sig: %s\n", yesNo(cert.GenesisSignature != "")) + fmt.Printf("multi_signature: %d bytes (raw)\n", len(cert.Multisignature)) + fmt.Printf("protocol_message: %d bytes (raw)\n", len(cert.ProtocolMessage)) + return 0 +} + +func yesNo(b bool) string { + if b { + return "yes" + } + return "no" +} + func cmdInfo(args []string) int { fs := flag.NewFlagSet("info", flag.ExitOnError) n, _, err := resolveNetwork(fs, args) @@ -272,19 +337,19 @@ func cloudURIs(locs []aggregator.Location) []string { return out } -func downloadWithBar(ctx context.Context, uri, dest string, expectedSize uint64) error { +func downloadWithBar(ctx context.Context, uri, dest string) error { fmt.Printf(" %s\n", uri) start := time.Now() var last int64 - cb := func(b int64) { + cb := func(read, total int64) { elapsed := time.Since(start).Seconds() - rate := float64(b) / elapsed + rate := float64(read) / elapsed pct := "" - if expectedSize > 0 { - pct = fmt.Sprintf("%5.1f%% ", float64(b)/float64(expectedSize)*100) + if total > 0 { + pct = fmt.Sprintf("%5.1f%% ", float64(read)/float64(total)*100) } - fmt.Printf("\r %s%s @ %s/s ", pct, humanSize(uint64(b)), humanSize(uint64(rate))) - last = b + fmt.Printf("\r %s%s @ %s/s ", pct, humanSize(uint64(read)), humanSize(uint64(rate))) + last = read } err := artifact.Download(ctx, uri, dest, "", cb) fmt.Printf("\r %s in %s \n", humanSize(uint64(last)), time.Since(start).Round(time.Second)) diff --git a/internal/aggregator/client.go b/internal/aggregator/client.go index ff98a8f..36c75db 100644 --- a/internal/aggregator/client.go +++ b/internal/aggregator/client.go @@ -158,3 +158,32 @@ func (c *Client) GetCertificate(ctx context.Context, hash string) (*Certificate, } return &out, nil } + +// CertChain walks previous_hash backwards from headHash until it hits the +// first certificate that carries a genesis_signature. Returns the chain +// ordered head-first. Caller can invert if root-first is preferred. +// +// The chain length is usually 1-3 certs per epoch boundary; an unbounded +// walk would be a footgun so it caps at maxDepth. +func (c *Client) CertChain(ctx context.Context, headHash string, maxDepth int) ([]*Certificate, error) { + if maxDepth <= 0 { + maxDepth = 1024 + } + var chain []*Certificate + next := headHash + for i := 0; i < maxDepth; i++ { + if next == "" { + return nil, fmt.Errorf("chain broke at depth %d: no previous_hash", i) + } + cert, err := c.GetCertificate(ctx, next) + if err != nil { + return nil, fmt.Errorf("depth %d (%s): %w", i, next, err) + } + chain = append(chain, cert) + if cert.GenesisSignature != "" { + return chain, nil + } + next = cert.PreviousHash + } + return nil, fmt.Errorf("cert chain exceeded max depth %d without reaching genesis", maxDepth) +} diff --git a/internal/artifact/download.go b/internal/artifact/download.go index 5719fc8..e29546a 100644 --- a/internal/artifact/download.go +++ b/internal/artifact/download.go @@ -14,6 +14,10 @@ import ( "time" ) +// ProgressFn is called with (bytesRead, totalBytes) where totalBytes is +// taken from the response Content-Length (0 if unknown — server didn't send). +type ProgressFn func(read, total int64) + // Download fetches a URL to destPath, resuming from a .part file if one // exists. If expectedSHA256 is non-empty, the final file is integrity-checked. // Progress is reported via the supplied callback (called with current bytes). @@ -25,7 +29,7 @@ import ( // - Resume is implemented via the HTTP Range header against the existing // .part file size; falls back to full download if the server refuses. // - destPath is atomically replaced only after SHA validation passes. -func Download(ctx context.Context, uri, destPath, expectedSHA256 string, progress func(bytes int64)) error { +func Download(ctx context.Context, uri, destPath, expectedSHA256 string, progress ProgressFn) error { if err := os.MkdirAll(filepath.Dir(destPath), 0o755); err != nil { return fmt.Errorf("mkdir: %w", err) } @@ -77,8 +81,15 @@ func Download(ctx context.Context, uri, destPath, expectedSHA256 string, progres } } + // totalBytes = existing + response.ContentLength if server sent one; + // for resumed partial responses Content-Length is the remaining, not total. + var totalSize int64 + if resp.ContentLength > 0 { + totalSize = existing + resp.ContentLength + } + w := io.MultiWriter(out, h) - total := existing + read := existing buf := make([]byte, 256*1024) lastProgress := time.Now() for { @@ -87,9 +98,9 @@ func Download(ctx context.Context, uri, destPath, expectedSHA256 string, progres if _, werr := w.Write(buf[:n]); werr != nil { return fmt.Errorf("write: %w", werr) } - total += int64(n) + read += int64(n) if progress != nil && time.Since(lastProgress) > 250*time.Millisecond { - progress(total) + progress(read, totalSize) lastProgress = time.Now() } } @@ -101,7 +112,7 @@ func Download(ctx context.Context, uri, destPath, expectedSHA256 string, progres } } if progress != nil { - progress(total) + progress(read, totalSize) } if err := out.Close(); err != nil { return err @@ -120,7 +131,7 @@ func Download(ctx context.Context, uri, destPath, expectedSHA256 string, progres var ErrNoLocations = errors.New("no download locations available") // DownloadFirst tries each URI in order until one succeeds. -func DownloadFirst(ctx context.Context, uris []string, destPath, expectedSHA256 string, progress func(int64)) error { +func DownloadFirst(ctx context.Context, uris []string, destPath, expectedSHA256 string, progress ProgressFn) error { if len(uris) == 0 { return ErrNoLocations }