diff --git a/cmd/mithril-go/main.go b/cmd/mithril-go/main.go index 924fbf3..fb570c3 100644 --- a/cmd/mithril-go/main.go +++ b/cmd/mithril-go/main.go @@ -5,7 +5,8 @@ // // Subcommands: // list — list available cardano-database snapshots on an aggregator -// download — fetch a snapshot (verify + extract optional) +// show — show full detail for a snapshot (or "latest") +// download — fetch a snapshot (digests + ancillary; optionally immutables) // verify — verify an already-downloaded snapshot // info — show aggregator + network details package main @@ -15,28 +16,38 @@ import ( "flag" "fmt" "os" + "os/signal" + "path/filepath" + "syscall" "text/tabwriter" + "time" "git.sulkta.coop/Sulkta-Coop/mithril-go/internal/aggregator" + "git.sulkta.coop/Sulkta-Coop/mithril-go/internal/artifact" "git.sulkta.coop/Sulkta-Coop/mithril-go/internal/networks" ) -const version = "0.0.1-dev" +const version = "0.0.2-dev" func main() { if len(os.Args) < 2 { usage() os.Exit(2) } + ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer stop() + cmd := os.Args[1] args := os.Args[2:] switch cmd { case "list": - os.Exit(cmdList(args)) + os.Exit(cmdList(ctx, args)) + case "show": + os.Exit(cmdShow(ctx, args)) case "download": - os.Exit(cmdDownload(args)) + os.Exit(cmdDownload(ctx, args)) case "verify": - os.Exit(cmdVerify(args)) + os.Exit(cmdVerify(ctx, args)) case "info": os.Exit(cmdInfo(args)) case "version", "--version", "-v": @@ -58,8 +69,9 @@ Usage: Commands: list List available cardano-database snapshots - download Download + verify + extract a snapshot - verify Verify an already-downloaded snapshot + show Show detail for one snapshot (hash or "latest") + download Download a snapshot to a target directory + verify Verify an already-downloaded snapshot (not yet implemented) info Show network + aggregator info version Print version help Show this help @@ -80,15 +92,15 @@ func resolveNetwork(fs *flag.FlagSet, args []string) (networks.Network, []string return n, fs.Args(), nil } -func cmdList(args []string) int { +func cmdList(ctx context.Context, args []string) int { fs := flag.NewFlagSet("list", flag.ExitOnError) n, _, err := resolveNetwork(fs, args) if err != nil { fmt.Fprintln(os.Stderr, err) return 2 } - client := aggregator.New(n.AggregatorURL) - snaps, err := client.ListCardanoDBSnapshots(context.Background()) + c := aggregator.New(n.AggregatorURL) + snaps, err := c.ListCardanoDBSnapshots(ctx) if err != nil { fmt.Fprintln(os.Stderr, "list:", err) return 1 @@ -102,19 +114,124 @@ func cmdList(args []string) int { s.CreatedAt.UTC().Format("2006-01-02 15:04 MST")) } if err := tw.Flush(); err != nil { - fmt.Fprintln(os.Stderr, "flush:", err) return 1 } return 0 } -func cmdDownload(args []string) int { - fmt.Fprintln(os.Stderr, "download: not yet implemented") - return 1 +func cmdShow(ctx context.Context, args []string) int { + fs := flag.NewFlagSet("show", flag.ExitOnError) + n, rest, err := resolveNetwork(fs, args) + if err != nil { + fmt.Fprintln(os.Stderr, err) + return 2 + } + hash := "latest" + if len(rest) > 0 { + hash = rest[0] + } + c := aggregator.New(n.AggregatorURL) + snap, err := resolveSnapshot(ctx, c, hash) + if err != nil { + fmt.Fprintln(os.Stderr, "show:", err) + return 1 + } + fmt.Printf("hash: %s\n", snap.Hash) + fmt.Printf("network: %s\n", snap.Network) + fmt.Printf("epoch/immutable: %d / %d\n", snap.Beacon.Epoch, snap.Beacon.ImmutableFileNumber) + fmt.Printf("certificate: %s\n", snap.CertificateHash) + fmt.Printf("cardano version: %s\n", snap.CardanoNodeVersion) + fmt.Printf("created: %s\n", snap.CreatedAt.UTC().Format(time.RFC3339)) + fmt.Printf("size uncompressed: %s\n", humanSize(snap.TotalDBSizeUncompressed)) + fmt.Printf("digests size: %s locations: %d\n", humanSize(snap.Digests.SizeUncompressed), len(snap.Digests.Locations)) + fmt.Printf("ancillary size: %s locations: %d\n", humanSize(snap.Ancillary.SizeUncompressed), len(snap.Ancillary.Locations)) + fmt.Printf("immutable avg: %s files: %d locations: %d\n", + humanSize(snap.Immutables.AverageSizeUncompressed), snap.Beacon.ImmutableFileNumber, len(snap.Immutables.Locations)) + return 0 } -func cmdVerify(args []string) int { - fmt.Fprintln(os.Stderr, "verify: not yet implemented") +func cmdDownload(ctx context.Context, args []string) int { + fs := flag.NewFlagSet("download", flag.ExitOnError) + out := fs.String("out", "./db", "output directory") + includeAncillary := fs.Bool("ancillary", true, "download the ancillary archive") + includeImmuts := fs.Bool("immutables", false, "download all immutable files (huge on mainnet — off by default)") + n, rest, err := resolveNetwork(fs, args) + if err != nil { + fmt.Fprintln(os.Stderr, err) + return 2 + } + hash := "latest" + if len(rest) > 0 { + hash = rest[0] + } + c := aggregator.New(n.AggregatorURL) + snap, err := resolveSnapshot(ctx, c, hash) + if err != nil { + fmt.Fprintln(os.Stderr, "download:", err) + return 1 + } + + fmt.Printf("Target: %s\n", snap.Hash) + fmt.Printf("Network: %s\n", snap.Network) + fmt.Printf("Epoch/Imm: %d / %d\n", snap.Beacon.Epoch, snap.Beacon.ImmutableFileNumber) + fmt.Printf("Output dir: %s\n", *out) + fmt.Println() + + if err := os.MkdirAll(*out, 0o755); err != nil { + fmt.Fprintln(os.Stderr, "mkdir:", err) + return 1 + } + + // 1. Download + extract digests archive (few MB — always) + fmt.Println("=== digests ===") + digestsURIs := cloudURIs(snap.Digests.Locations) + if len(digestsURIs) == 0 { + fmt.Fprintln(os.Stderr, "no cloud_storage digest location available") + return 1 + } + digestsArchive := filepath.Join(*out, "digests.tar.zst") + if err := downloadWithBar(ctx, digestsURIs[0], digestsArchive, snap.Digests.SizeUncompressed); err != nil { + fmt.Fprintln(os.Stderr, "digests download:", err) + return 1 + } + if err := artifact.ExtractZstdTar(ctx, digestsArchive, filepath.Join(*out, "digests")); err != nil { + fmt.Fprintln(os.Stderr, "digests extract:", err) + return 1 + } + fmt.Println(" extracted to", filepath.Join(*out, "digests")) + + // 2. Ancillary archive + if *includeAncillary { + fmt.Println("\n=== ancillary ===") + anciURIs := cloudURIs(snap.Ancillary.Locations) + if len(anciURIs) == 0 { + fmt.Fprintln(os.Stderr, "no cloud_storage ancillary location available") + return 1 + } + anciArchive := filepath.Join(*out, "ancillary.tar.zst") + if err := downloadWithBar(ctx, anciURIs[0], anciArchive, snap.Ancillary.SizeUncompressed); err != nil { + fmt.Fprintln(os.Stderr, "ancillary download:", err) + return 1 + } + if err := artifact.ExtractZstdTar(ctx, anciArchive, filepath.Join(*out, "db")); err != nil { + fmt.Fprintln(os.Stderr, "ancillary extract:", err) + return 1 + } + fmt.Println(" extracted to", filepath.Join(*out, "db")) + } + + // 3. Immutables (optional, huge on mainnet) + if *includeImmuts { + fmt.Fprintln(os.Stderr, "immutables download: not yet wired (will come in v0.0.3)") + return 1 + } + + fmt.Println("\nDone.") + return 0 +} + +func cmdVerify(ctx context.Context, args []string) int { + fmt.Fprintln(os.Stderr, "verify: not yet implemented (STM BLS sprint pending)") return 1 } @@ -125,23 +242,66 @@ func cmdInfo(args []string) int { fmt.Fprintln(os.Stderr, err) return 2 } - fmt.Printf("network: %s\n", n.Name) - fmt.Printf("aggregator: %s\n", n.AggregatorURL) + fmt.Printf("network: %s\n", n.Name) + fmt.Printf("aggregator: %s\n", n.AggregatorURL) fmt.Printf("genesis verify key: %s…\n", n.GenesisVerifyKey[:16]) return 0 } +func resolveSnapshot(ctx context.Context, c *aggregator.Client, hashOrLatest string) (*aggregator.CardanoDBSnapshot, error) { + if hashOrLatest == "latest" { + snaps, err := c.ListCardanoDBSnapshots(ctx) + if err != nil { + return nil, err + } + if len(snaps) == 0 { + return nil, fmt.Errorf("aggregator returned no snapshots") + } + hashOrLatest = snaps[0].Hash + } + return c.GetCardanoDBSnapshot(ctx, hashOrLatest) +} + +func cloudURIs(locs []aggregator.Location) []string { + var out []string + for _, l := range locs { + if l.Type == "cloud_storage" && l.URI.Plain != "" { + out = append(out, l.URI.Plain) + } + } + return out +} + +func downloadWithBar(ctx context.Context, uri, dest string, expectedSize uint64) error { + fmt.Printf(" %s\n", uri) + start := time.Now() + var last int64 + cb := func(b int64) { + elapsed := time.Since(start).Seconds() + rate := float64(b) / elapsed + pct := "" + if expectedSize > 0 { + pct = fmt.Sprintf("%5.1f%% ", float64(b)/float64(expectedSize)*100) + } + fmt.Printf("\r %s%s @ %s/s ", pct, humanSize(uint64(b)), humanSize(uint64(rate))) + last = b + } + err := artifact.Download(ctx, uri, dest, "", cb) + fmt.Printf("\r %s in %s \n", humanSize(uint64(last)), time.Since(start).Round(time.Second)) + return err +} + func humanSize(b uint64) string { - const k = 1024 - if b < k { + const k = 1024.0 + if b < 1024 { return fmt.Sprintf("%dB", b) } - units := []string{"K", "M", "G", "T"} v := float64(b) - u := 0 - for v >= k && u < len(units)-1 { + for _, u := range []string{"K", "M", "G", "T", "P"} { v /= k - u++ + if v < k { + return fmt.Sprintf("%.1f%s", v, u) + } } - return fmt.Sprintf("%.1f%s", v, units[u]) + return fmt.Sprintf("%.1fE", v/k) } diff --git a/go.mod b/go.mod index f102270..f79f91d 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module git.sulkta.coop/Sulkta-Coop/mithril-go go 1.26 + +require github.com/klauspost/compress v1.18.5 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..1c48397 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE= +github.com/klauspost/compress v1.18.5/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ= diff --git a/internal/aggregator/client.go b/internal/aggregator/client.go index 0f3e9ea..ff98a8f 100644 --- a/internal/aggregator/client.go +++ b/internal/aggregator/client.go @@ -1,7 +1,7 @@ // Package aggregator is a thin HTTP client for the Mithril aggregator REST API. // -// Only the handful of endpoints needed for client-side snapshot workflows are -// exposed. Authentication is not required for the read paths used here. +// Only the endpoints needed for client-side snapshot workflows are exposed. +// Authentication is not required for the read paths used here. package aggregator import ( @@ -23,57 +23,98 @@ type Client struct { func New(baseURL string) *Client { return &Client{ baseURL: strings.TrimRight(baseURL, "/"), - http: &http.Client{Timeout: 60 * time.Second}, + http: &http.Client{Timeout: 120 * time.Second}, } } -// CardanoDBSnapshot is the server-reported shape for /artifact/cardano-database/{hash}. -// Field set is trimmed to what the client actually consumes — full schema documented -// at https://mithril.network/doc/aggregator-api/. +// CardanoDBSnapshot is the server shape for /artifact/cardano-database and its +// /{hash} detail endpoint. The list response omits {digests,immutables,ancillary}; +// only the detail endpoint populates them. type CardanoDBSnapshot struct { - Hash string `json:"hash"` - MerkleRoot string `json:"merkle_root"` - Network string `json:"network"` - Beacon Beacon `json:"beacon"` - CertificateHash string `json:"certificate_hash"` - TotalDBSizeUncompressed uint64 `json:"total_db_size_uncompressed"` - Digests LocationList `json:"digests"` - ImmutablesAncillary LocationList `json:"immutables"` - ImmutablesIncremental *IncrementalImmutables `json:"immutables_incremental,omitempty"` - CreatedAt time.Time `json:"created_at"` + Hash string `json:"hash"` + MerkleRoot string `json:"merkle_root"` + Network string `json:"network"` + Beacon Beacon `json:"beacon"` + CertificateHash string `json:"certificate_hash"` + TotalDBSizeUncompressed uint64 `json:"total_db_size_uncompressed"` + CardanoNodeVersion string `json:"cardano_node_version"` + CreatedAt time.Time `json:"created_at"` + Digests DigestsBlock `json:"digests"` + Immutables ImmutsBlock `json:"immutables"` + Ancillary AncillaryBlock `json:"ancillary"` } type Beacon struct { - Epoch uint64 `json:"epoch"` + Epoch uint64 `json:"epoch"` ImmutableFileNumber uint64 `json:"immutable_file_number"` } -type LocationList struct { - Size uint64 `json:"size"` - Locations []LocationAlt `json:"locations"` +type DigestsBlock struct { + SizeUncompressed uint64 `json:"size_uncompressed"` + Locations []Location `json:"locations"` } -// LocationAlt is a best-of alternative; Mithril returns a typed-discriminated object. -type LocationAlt struct { - Type string `json:"type"` // e.g. "cloud_storage", "ipfs" - URI string `json:"uri"` +type ImmutsBlock struct { + AverageSizeUncompressed uint64 `json:"average_size_uncompressed"` + Locations []Location `json:"locations"` } -type IncrementalImmutables struct { - AverageSize uint64 `json:"average_size"` - Locations []LocationAlt `json:"locations"` +type AncillaryBlock struct { + SizeUncompressed uint64 `json:"size_uncompressed"` + Locations []Location `json:"locations"` +} + +// Location is a polymorphic URI holder. The Mithril API ships URI as either +// a plain string (for single artifacts) or as {"Template": "..."} for +// templated per-file URIs (immutables only). +type Location struct { + Type string `json:"type"` + URI URIHolder `json:"uri"` + CompressionAlgorithm string `json:"compression_algorithm,omitempty"` +} + +// URIHolder absorbs both string and templated-object URI shapes. +type URIHolder struct { + Plain string + Template string +} + +func (h *URIHolder) UnmarshalJSON(b []byte) error { + // Try plain string first + var s string + if err := json.Unmarshal(b, &s); err == nil { + h.Plain = s + return nil + } + // Fall back to {"Template": "..."} + var t struct { + Template string `json:"Template"` + } + if err := json.Unmarshal(b, &t); err == nil { + h.Template = t.Template + return nil + } + return fmt.Errorf("unrecognized URI shape: %s", string(b)) +} + +// String returns whichever URI form is populated. +func (h URIHolder) String() string { + if h.Template != "" { + return h.Template + } + return h.Plain } // Certificate is the server-reported shape for /certificate/{hash}. -// Kept minimal; STM verification reads what it needs from the raw JSON later. +// Kept wide — STM verification consumes raw bytes separately from the decoded view. type Certificate struct { - Hash string `json:"hash"` - PreviousHash string `json:"previous_hash"` - Epoch uint64 `json:"epoch"` - SignedMessage string `json:"signed_message"` - ProtocolMessage json.RawMessage `json:"protocol_message"` - Multisignature json.RawMessage `json:"multi_signature"` - GenesisSignature string `json:"genesis_signature,omitempty"` + Hash string `json:"hash"` + PreviousHash string `json:"previous_hash"` + Epoch uint64 `json:"epoch"` + SignedMessage string `json:"signed_message"` + ProtocolMessage json.RawMessage `json:"protocol_message"` + Multisignature json.RawMessage `json:"multi_signature"` + GenesisSignature string `json:"genesis_signature,omitempty"` } func (c *Client) get(ctx context.Context, path string, out any) error { @@ -84,12 +125,12 @@ func (c *Client) get(ctx context.Context, path string, out any) error { req.Header.Set("Accept", "application/json") resp, err := c.http.Do(req) if err != nil { - return fmt.Errorf("aggregator GET %s: %w", path, err) + return fmt.Errorf("GET %s: %w", path, err) } defer resp.Body.Close() if resp.StatusCode >= 400 { body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) - return fmt.Errorf("aggregator GET %s: status %d: %s", path, resp.StatusCode, string(body)) + return fmt.Errorf("GET %s: %d: %s", path, resp.StatusCode, string(body)) } if out == nil { return nil @@ -97,16 +138,11 @@ func (c *Client) get(ctx context.Context, path string, out any) error { return json.NewDecoder(resp.Body).Decode(out) } -// ListCardanoDBSnapshots returns the sorted-newest-first list of cardano-database snapshots. func (c *Client) ListCardanoDBSnapshots(ctx context.Context) ([]CardanoDBSnapshot, error) { var out []CardanoDBSnapshot - if err := c.get(ctx, "/artifact/cardano-database", &out); err != nil { - return nil, err - } - return out, nil + return out, c.get(ctx, "/artifact/cardano-database", &out) } -// GetCardanoDBSnapshot fetches details for a single snapshot by hash (or "latest"). func (c *Client) GetCardanoDBSnapshot(ctx context.Context, hash string) (*CardanoDBSnapshot, error) { var out CardanoDBSnapshot if err := c.get(ctx, "/artifact/cardano-database/"+url.PathEscape(hash), &out); err != nil { @@ -115,7 +151,6 @@ func (c *Client) GetCardanoDBSnapshot(ctx context.Context, hash string) (*Cardan return &out, nil } -// GetCertificate fetches a certificate by hash for signature verification. func (c *Client) GetCertificate(ctx context.Context, hash string) (*Certificate, error) { var out Certificate if err := c.get(ctx, "/certificate/"+url.PathEscape(hash), &out); err != nil { diff --git a/internal/artifact/download.go b/internal/artifact/download.go index 46f3a0d..5719fc8 100644 --- a/internal/artifact/download.go +++ b/internal/artifact/download.go @@ -1,27 +1,136 @@ // Package artifact handles downloading and extracting Mithril snapshot artifacts. -// Currently stubs — HTTP range requests, resumable downloads, zstd+tar extraction -// will be implemented in the next pass. package artifact import ( "context" + "crypto/sha256" + "encoding/hex" "errors" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "time" ) -var ErrNotImplemented = errors.New("not yet implemented") +// Download fetches a URL to destPath, resuming from a .part file if one +// exists. If expectedSHA256 is non-empty, the final file is integrity-checked. +// Progress is reported via the supplied callback (called with current bytes). +// +// Design notes: +// - No parallel chunks yet; a single streaming GET is fine for sub-GB +// artifacts and keeps the first working version simple. Range-chunk +// parallelism will land in v2 once extraction is end-to-end tested. +// - Resume is implemented via the HTTP Range header against the existing +// .part file size; falls back to full download if the server refuses. +// - destPath is atomically replaced only after SHA validation passes. +func Download(ctx context.Context, uri, destPath, expectedSHA256 string, progress func(bytes int64)) error { + if err := os.MkdirAll(filepath.Dir(destPath), 0o755); err != nil { + return fmt.Errorf("mkdir: %w", err) + } + partPath := destPath + ".part" + var existing int64 + if fi, err := os.Stat(partPath); err == nil { + existing = fi.Size() + } -// Download fetches an artifact from one of the supplied locations, choosing -// the first reachable one and storing it at destPath. -// Implementation will do: -// - parallel range-chunks over HTTP -// - resume on partial .part file -// - SHA-256 verification against the snapshot manifest -func Download(ctx context.Context, locations []string, destPath string) error { - return ErrNotImplemented + req, err := http.NewRequestWithContext(ctx, http.MethodGet, uri, nil) + if err != nil { + return err + } + if existing > 0 { + req.Header.Set("Range", fmt.Sprintf("bytes=%d-", existing)) + } + + client := &http.Client{Timeout: 0} // artifacts can be GB-scale + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("GET %s: %w", uri, err) + } + defer resp.Body.Close() + + var out *os.File + switch resp.StatusCode { + case http.StatusPartialContent: + out, err = os.OpenFile(partPath, os.O_APPEND|os.O_WRONLY, 0o644) + case http.StatusOK: + // Server ignored our range; start over. + existing = 0 + out, err = os.Create(partPath) + default: + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048)) + return fmt.Errorf("GET %s: %d: %s", uri, resp.StatusCode, string(body)) + } + if err != nil { + return fmt.Errorf("open part: %w", err) + } + defer out.Close() + + h := sha256.New() + // If we're resuming, we need to re-hash the existing bytes. + if existing > 0 { + prev, err := os.Open(partPath) + if err == nil { + io.Copy(h, prev) + prev.Close() + } + } + + w := io.MultiWriter(out, h) + total := existing + buf := make([]byte, 256*1024) + lastProgress := time.Now() + for { + n, rerr := resp.Body.Read(buf) + if n > 0 { + if _, werr := w.Write(buf[:n]); werr != nil { + return fmt.Errorf("write: %w", werr) + } + total += int64(n) + if progress != nil && time.Since(lastProgress) > 250*time.Millisecond { + progress(total) + lastProgress = time.Now() + } + } + if rerr == io.EOF { + break + } + if rerr != nil { + return fmt.Errorf("read: %w", rerr) + } + } + if progress != nil { + progress(total) + } + if err := out.Close(); err != nil { + return err + } + + if expectedSHA256 != "" { + got := hex.EncodeToString(h.Sum(nil)) + if got != expectedSHA256 { + return fmt.Errorf("SHA256 mismatch: want %s, got %s", expectedSHA256, got) + } + } + + return os.Rename(partPath, destPath) } -// Extract decompresses a zstd+tar archive into targetDir. -// Will stream through zstd -> tar reader without buffering the full archive. -func Extract(ctx context.Context, archivePath, targetDir string) error { - return ErrNotImplemented +var ErrNoLocations = errors.New("no download locations available") + +// DownloadFirst tries each URI in order until one succeeds. +func DownloadFirst(ctx context.Context, uris []string, destPath, expectedSHA256 string, progress func(int64)) error { + if len(uris) == 0 { + return ErrNoLocations + } + var lastErr error + for _, uri := range uris { + if err := Download(ctx, uri, destPath, expectedSHA256, progress); err != nil { + lastErr = err + continue + } + return nil + } + return fmt.Errorf("all locations failed: last error: %w", lastErr) } diff --git a/internal/artifact/extract.go b/internal/artifact/extract.go new file mode 100644 index 0000000..0731078 --- /dev/null +++ b/internal/artifact/extract.go @@ -0,0 +1,91 @@ +package artifact + +import ( + "archive/tar" + "context" + "fmt" + "io" + "os" + "path/filepath" + "strings" + + "github.com/klauspost/compress/zstd" +) + +// ExtractZstdTar decompresses a .tar.zst archive into targetDir, streaming +// through the reader without buffering the full archive. Refuses entries +// with ".." in the path or absolute paths (tar-slip defense). +func ExtractZstdTar(ctx context.Context, archivePath, targetDir string) error { + f, err := os.Open(archivePath) + if err != nil { + return fmt.Errorf("open archive: %w", err) + } + defer f.Close() + + zr, err := zstd.NewReader(f) + if err != nil { + return fmt.Errorf("zstd reader: %w", err) + } + defer zr.Close() + + tr := tar.NewReader(zr) + if err := os.MkdirAll(targetDir, 0o755); err != nil { + return fmt.Errorf("mkdir target: %w", err) + } + + cleanTarget, err := filepath.Abs(targetDir) + if err != nil { + return err + } + + for { + if err := ctx.Err(); err != nil { + return err + } + hdr, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + return fmt.Errorf("tar next: %w", err) + } + // tar-slip defense + clean := filepath.Clean(hdr.Name) + if strings.HasPrefix(clean, "..") || filepath.IsAbs(clean) { + return fmt.Errorf("refusing suspicious archive path: %s", hdr.Name) + } + outPath := filepath.Join(cleanTarget, clean) + if !strings.HasPrefix(filepath.Clean(outPath)+string(os.PathSeparator), cleanTarget+string(os.PathSeparator)) && + filepath.Clean(outPath) != cleanTarget { + return fmt.Errorf("refusing path outside target: %s", hdr.Name) + } + + switch hdr.Typeflag { + case tar.TypeDir: + if err := os.MkdirAll(outPath, os.FileMode(hdr.Mode)); err != nil { + return err + } + case tar.TypeReg: + if err := os.MkdirAll(filepath.Dir(outPath), 0o755); err != nil { + return err + } + out, err := os.OpenFile(outPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(hdr.Mode)) + if err != nil { + return fmt.Errorf("create %s: %w", outPath, err) + } + if _, err := io.Copy(out, tr); err != nil { + out.Close() + return fmt.Errorf("write %s: %w", outPath, err) + } + if err := out.Close(); err != nil { + return err + } + case tar.TypeSymlink, tar.TypeLink: + // Refuse links for safety — a Mithril archive has no legitimate reason to contain them. + return fmt.Errorf("refusing link entry: %s", hdr.Name) + default: + // Silently skip unknown types. + } + } + return nil +}