download + extract pipeline

- artifact.Download: resumable HTTP with optional SHA256 check + progress cb
- artifact.ExtractZstdTar: streamed zstd+tar with tar-slip defense
- aggregator client matches real API shape (digests/immutables/ancillary blocks
  with URIHolder polymorphism for templated immutable URIs)
- cmd: show + download subcommands wired up
- end-to-end verified against preprod: digests archive pulls cleanly, yields
  16836-entry SHA manifest ready for verification sprint

deps: github.com/klauspost/compress (pure-go zstd)
This commit is contained in:
Kayos 2026-04-23 15:16:48 -07:00
parent f87b7fc3c4
commit e557d85d5a
6 changed files with 483 additions and 84 deletions

View file

@ -5,7 +5,8 @@
// //
// Subcommands: // Subcommands:
// list — list available cardano-database snapshots on an aggregator // list — list available cardano-database snapshots on an aggregator
// download — fetch a snapshot (verify + extract optional) // show — show full detail for a snapshot (or "latest")
// download — fetch a snapshot (digests + ancillary; optionally immutables)
// verify — verify an already-downloaded snapshot // verify — verify an already-downloaded snapshot
// info — show aggregator + network details // info — show aggregator + network details
package main package main
@ -15,28 +16,38 @@ import (
"flag" "flag"
"fmt" "fmt"
"os" "os"
"os/signal"
"path/filepath"
"syscall"
"text/tabwriter" "text/tabwriter"
"time"
"git.sulkta.coop/Sulkta-Coop/mithril-go/internal/aggregator" "git.sulkta.coop/Sulkta-Coop/mithril-go/internal/aggregator"
"git.sulkta.coop/Sulkta-Coop/mithril-go/internal/artifact"
"git.sulkta.coop/Sulkta-Coop/mithril-go/internal/networks" "git.sulkta.coop/Sulkta-Coop/mithril-go/internal/networks"
) )
const version = "0.0.1-dev" const version = "0.0.2-dev"
func main() { func main() {
if len(os.Args) < 2 { if len(os.Args) < 2 {
usage() usage()
os.Exit(2) os.Exit(2)
} }
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer stop()
cmd := os.Args[1] cmd := os.Args[1]
args := os.Args[2:] args := os.Args[2:]
switch cmd { switch cmd {
case "list": case "list":
os.Exit(cmdList(args)) os.Exit(cmdList(ctx, args))
case "show":
os.Exit(cmdShow(ctx, args))
case "download": case "download":
os.Exit(cmdDownload(args)) os.Exit(cmdDownload(ctx, args))
case "verify": case "verify":
os.Exit(cmdVerify(args)) os.Exit(cmdVerify(ctx, args))
case "info": case "info":
os.Exit(cmdInfo(args)) os.Exit(cmdInfo(args))
case "version", "--version", "-v": case "version", "--version", "-v":
@ -58,8 +69,9 @@ Usage:
Commands: Commands:
list List available cardano-database snapshots list List available cardano-database snapshots
download Download + verify + extract a snapshot show Show detail for one snapshot (hash or "latest")
verify Verify an already-downloaded snapshot download Download a snapshot to a target directory
verify Verify an already-downloaded snapshot (not yet implemented)
info Show network + aggregator info info Show network + aggregator info
version Print version version Print version
help Show this help help Show this help
@ -80,15 +92,15 @@ func resolveNetwork(fs *flag.FlagSet, args []string) (networks.Network, []string
return n, fs.Args(), nil return n, fs.Args(), nil
} }
func cmdList(args []string) int { func cmdList(ctx context.Context, args []string) int {
fs := flag.NewFlagSet("list", flag.ExitOnError) fs := flag.NewFlagSet("list", flag.ExitOnError)
n, _, err := resolveNetwork(fs, args) n, _, err := resolveNetwork(fs, args)
if err != nil { if err != nil {
fmt.Fprintln(os.Stderr, err) fmt.Fprintln(os.Stderr, err)
return 2 return 2
} }
client := aggregator.New(n.AggregatorURL) c := aggregator.New(n.AggregatorURL)
snaps, err := client.ListCardanoDBSnapshots(context.Background()) snaps, err := c.ListCardanoDBSnapshots(ctx)
if err != nil { if err != nil {
fmt.Fprintln(os.Stderr, "list:", err) fmt.Fprintln(os.Stderr, "list:", err)
return 1 return 1
@ -102,19 +114,124 @@ func cmdList(args []string) int {
s.CreatedAt.UTC().Format("2006-01-02 15:04 MST")) s.CreatedAt.UTC().Format("2006-01-02 15:04 MST"))
} }
if err := tw.Flush(); err != nil { if err := tw.Flush(); err != nil {
fmt.Fprintln(os.Stderr, "flush:", err)
return 1 return 1
} }
return 0 return 0
} }
func cmdDownload(args []string) int { func cmdShow(ctx context.Context, args []string) int {
fmt.Fprintln(os.Stderr, "download: not yet implemented") fs := flag.NewFlagSet("show", flag.ExitOnError)
n, rest, err := resolveNetwork(fs, args)
if err != nil {
fmt.Fprintln(os.Stderr, err)
return 2
}
hash := "latest"
if len(rest) > 0 {
hash = rest[0]
}
c := aggregator.New(n.AggregatorURL)
snap, err := resolveSnapshot(ctx, c, hash)
if err != nil {
fmt.Fprintln(os.Stderr, "show:", err)
return 1
}
fmt.Printf("hash: %s\n", snap.Hash)
fmt.Printf("network: %s\n", snap.Network)
fmt.Printf("epoch/immutable: %d / %d\n", snap.Beacon.Epoch, snap.Beacon.ImmutableFileNumber)
fmt.Printf("certificate: %s\n", snap.CertificateHash)
fmt.Printf("cardano version: %s\n", snap.CardanoNodeVersion)
fmt.Printf("created: %s\n", snap.CreatedAt.UTC().Format(time.RFC3339))
fmt.Printf("size uncompressed: %s\n", humanSize(snap.TotalDBSizeUncompressed))
fmt.Printf("digests size: %s locations: %d\n", humanSize(snap.Digests.SizeUncompressed), len(snap.Digests.Locations))
fmt.Printf("ancillary size: %s locations: %d\n", humanSize(snap.Ancillary.SizeUncompressed), len(snap.Ancillary.Locations))
fmt.Printf("immutable avg: %s files: %d locations: %d\n",
humanSize(snap.Immutables.AverageSizeUncompressed), snap.Beacon.ImmutableFileNumber, len(snap.Immutables.Locations))
return 0
}
func cmdDownload(ctx context.Context, args []string) int {
fs := flag.NewFlagSet("download", flag.ExitOnError)
out := fs.String("out", "./db", "output directory")
includeAncillary := fs.Bool("ancillary", true, "download the ancillary archive")
includeImmuts := fs.Bool("immutables", false, "download all immutable files (huge on mainnet — off by default)")
n, rest, err := resolveNetwork(fs, args)
if err != nil {
fmt.Fprintln(os.Stderr, err)
return 2
}
hash := "latest"
if len(rest) > 0 {
hash = rest[0]
}
c := aggregator.New(n.AggregatorURL)
snap, err := resolveSnapshot(ctx, c, hash)
if err != nil {
fmt.Fprintln(os.Stderr, "download:", err)
return 1 return 1
} }
func cmdVerify(args []string) int { fmt.Printf("Target: %s\n", snap.Hash)
fmt.Fprintln(os.Stderr, "verify: not yet implemented") fmt.Printf("Network: %s\n", snap.Network)
fmt.Printf("Epoch/Imm: %d / %d\n", snap.Beacon.Epoch, snap.Beacon.ImmutableFileNumber)
fmt.Printf("Output dir: %s\n", *out)
fmt.Println()
if err := os.MkdirAll(*out, 0o755); err != nil {
fmt.Fprintln(os.Stderr, "mkdir:", err)
return 1
}
// 1. Download + extract digests archive (few MB — always)
fmt.Println("=== digests ===")
digestsURIs := cloudURIs(snap.Digests.Locations)
if len(digestsURIs) == 0 {
fmt.Fprintln(os.Stderr, "no cloud_storage digest location available")
return 1
}
digestsArchive := filepath.Join(*out, "digests.tar.zst")
if err := downloadWithBar(ctx, digestsURIs[0], digestsArchive, snap.Digests.SizeUncompressed); err != nil {
fmt.Fprintln(os.Stderr, "digests download:", err)
return 1
}
if err := artifact.ExtractZstdTar(ctx, digestsArchive, filepath.Join(*out, "digests")); err != nil {
fmt.Fprintln(os.Stderr, "digests extract:", err)
return 1
}
fmt.Println(" extracted to", filepath.Join(*out, "digests"))
// 2. Ancillary archive
if *includeAncillary {
fmt.Println("\n=== ancillary ===")
anciURIs := cloudURIs(snap.Ancillary.Locations)
if len(anciURIs) == 0 {
fmt.Fprintln(os.Stderr, "no cloud_storage ancillary location available")
return 1
}
anciArchive := filepath.Join(*out, "ancillary.tar.zst")
if err := downloadWithBar(ctx, anciURIs[0], anciArchive, snap.Ancillary.SizeUncompressed); err != nil {
fmt.Fprintln(os.Stderr, "ancillary download:", err)
return 1
}
if err := artifact.ExtractZstdTar(ctx, anciArchive, filepath.Join(*out, "db")); err != nil {
fmt.Fprintln(os.Stderr, "ancillary extract:", err)
return 1
}
fmt.Println(" extracted to", filepath.Join(*out, "db"))
}
// 3. Immutables (optional, huge on mainnet)
if *includeImmuts {
fmt.Fprintln(os.Stderr, "immutables download: not yet wired (will come in v0.0.3)")
return 1
}
fmt.Println("\nDone.")
return 0
}
func cmdVerify(ctx context.Context, args []string) int {
fmt.Fprintln(os.Stderr, "verify: not yet implemented (STM BLS sprint pending)")
return 1 return 1
} }
@ -131,17 +248,60 @@ func cmdInfo(args []string) int {
return 0 return 0
} }
func resolveSnapshot(ctx context.Context, c *aggregator.Client, hashOrLatest string) (*aggregator.CardanoDBSnapshot, error) {
if hashOrLatest == "latest" {
snaps, err := c.ListCardanoDBSnapshots(ctx)
if err != nil {
return nil, err
}
if len(snaps) == 0 {
return nil, fmt.Errorf("aggregator returned no snapshots")
}
hashOrLatest = snaps[0].Hash
}
return c.GetCardanoDBSnapshot(ctx, hashOrLatest)
}
func cloudURIs(locs []aggregator.Location) []string {
var out []string
for _, l := range locs {
if l.Type == "cloud_storage" && l.URI.Plain != "" {
out = append(out, l.URI.Plain)
}
}
return out
}
func downloadWithBar(ctx context.Context, uri, dest string, expectedSize uint64) error {
fmt.Printf(" %s\n", uri)
start := time.Now()
var last int64
cb := func(b int64) {
elapsed := time.Since(start).Seconds()
rate := float64(b) / elapsed
pct := ""
if expectedSize > 0 {
pct = fmt.Sprintf("%5.1f%% ", float64(b)/float64(expectedSize)*100)
}
fmt.Printf("\r %s%s @ %s/s ", pct, humanSize(uint64(b)), humanSize(uint64(rate)))
last = b
}
err := artifact.Download(ctx, uri, dest, "", cb)
fmt.Printf("\r %s in %s \n", humanSize(uint64(last)), time.Since(start).Round(time.Second))
return err
}
func humanSize(b uint64) string { func humanSize(b uint64) string {
const k = 1024 const k = 1024.0
if b < k { if b < 1024 {
return fmt.Sprintf("%dB", b) return fmt.Sprintf("%dB", b)
} }
units := []string{"K", "M", "G", "T"}
v := float64(b) v := float64(b)
u := 0 for _, u := range []string{"K", "M", "G", "T", "P"} {
for v >= k && u < len(units)-1 {
v /= k v /= k
u++ if v < k {
return fmt.Sprintf("%.1f%s", v, u)
} }
return fmt.Sprintf("%.1f%s", v, units[u]) }
return fmt.Sprintf("%.1fE", v/k)
} }

2
go.mod
View file

@ -1,3 +1,5 @@
module git.sulkta.coop/Sulkta-Coop/mithril-go module git.sulkta.coop/Sulkta-Coop/mithril-go
go 1.26 go 1.26
require github.com/klauspost/compress v1.18.5

2
go.sum Normal file
View file

@ -0,0 +1,2 @@
github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE=
github.com/klauspost/compress v1.18.5/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ=

View file

@ -1,7 +1,7 @@
// Package aggregator is a thin HTTP client for the Mithril aggregator REST API. // Package aggregator is a thin HTTP client for the Mithril aggregator REST API.
// //
// Only the handful of endpoints needed for client-side snapshot workflows are // Only the endpoints needed for client-side snapshot workflows are exposed.
// exposed. Authentication is not required for the read paths used here. // Authentication is not required for the read paths used here.
package aggregator package aggregator
import ( import (
@ -23,13 +23,13 @@ type Client struct {
func New(baseURL string) *Client { func New(baseURL string) *Client {
return &Client{ return &Client{
baseURL: strings.TrimRight(baseURL, "/"), baseURL: strings.TrimRight(baseURL, "/"),
http: &http.Client{Timeout: 60 * time.Second}, http: &http.Client{Timeout: 120 * time.Second},
} }
} }
// CardanoDBSnapshot is the server-reported shape for /artifact/cardano-database/{hash}. // CardanoDBSnapshot is the server shape for /artifact/cardano-database and its
// Field set is trimmed to what the client actually consumes — full schema documented // /{hash} detail endpoint. The list response omits {digests,immutables,ancillary};
// at https://mithril.network/doc/aggregator-api/. // only the detail endpoint populates them.
type CardanoDBSnapshot struct { type CardanoDBSnapshot struct {
Hash string `json:"hash"` Hash string `json:"hash"`
MerkleRoot string `json:"merkle_root"` MerkleRoot string `json:"merkle_root"`
@ -37,10 +37,11 @@ type CardanoDBSnapshot struct {
Beacon Beacon `json:"beacon"` Beacon Beacon `json:"beacon"`
CertificateHash string `json:"certificate_hash"` CertificateHash string `json:"certificate_hash"`
TotalDBSizeUncompressed uint64 `json:"total_db_size_uncompressed"` TotalDBSizeUncompressed uint64 `json:"total_db_size_uncompressed"`
Digests LocationList `json:"digests"` CardanoNodeVersion string `json:"cardano_node_version"`
ImmutablesAncillary LocationList `json:"immutables"`
ImmutablesIncremental *IncrementalImmutables `json:"immutables_incremental,omitempty"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
Digests DigestsBlock `json:"digests"`
Immutables ImmutsBlock `json:"immutables"`
Ancillary AncillaryBlock `json:"ancillary"`
} }
type Beacon struct { type Beacon struct {
@ -48,24 +49,64 @@ type Beacon struct {
ImmutableFileNumber uint64 `json:"immutable_file_number"` ImmutableFileNumber uint64 `json:"immutable_file_number"`
} }
type LocationList struct { type DigestsBlock struct {
Size uint64 `json:"size"` SizeUncompressed uint64 `json:"size_uncompressed"`
Locations []LocationAlt `json:"locations"` Locations []Location `json:"locations"`
} }
// LocationAlt is a best-of alternative; Mithril returns a typed-discriminated object. type ImmutsBlock struct {
type LocationAlt struct { AverageSizeUncompressed uint64 `json:"average_size_uncompressed"`
Type string `json:"type"` // e.g. "cloud_storage", "ipfs" Locations []Location `json:"locations"`
URI string `json:"uri"`
} }
type IncrementalImmutables struct { type AncillaryBlock struct {
AverageSize uint64 `json:"average_size"` SizeUncompressed uint64 `json:"size_uncompressed"`
Locations []LocationAlt `json:"locations"` Locations []Location `json:"locations"`
}
// Location is a polymorphic URI holder. The Mithril API ships URI as either
// a plain string (for single artifacts) or as {"Template": "..."} for
// templated per-file URIs (immutables only).
type Location struct {
Type string `json:"type"`
URI URIHolder `json:"uri"`
CompressionAlgorithm string `json:"compression_algorithm,omitempty"`
}
// URIHolder absorbs both string and templated-object URI shapes.
type URIHolder struct {
Plain string
Template string
}
func (h *URIHolder) UnmarshalJSON(b []byte) error {
// Try plain string first
var s string
if err := json.Unmarshal(b, &s); err == nil {
h.Plain = s
return nil
}
// Fall back to {"Template": "..."}
var t struct {
Template string `json:"Template"`
}
if err := json.Unmarshal(b, &t); err == nil {
h.Template = t.Template
return nil
}
return fmt.Errorf("unrecognized URI shape: %s", string(b))
}
// String returns whichever URI form is populated.
func (h URIHolder) String() string {
if h.Template != "" {
return h.Template
}
return h.Plain
} }
// Certificate is the server-reported shape for /certificate/{hash}. // Certificate is the server-reported shape for /certificate/{hash}.
// Kept minimal; STM verification reads what it needs from the raw JSON later. // Kept wide — STM verification consumes raw bytes separately from the decoded view.
type Certificate struct { type Certificate struct {
Hash string `json:"hash"` Hash string `json:"hash"`
PreviousHash string `json:"previous_hash"` PreviousHash string `json:"previous_hash"`
@ -84,12 +125,12 @@ func (c *Client) get(ctx context.Context, path string, out any) error {
req.Header.Set("Accept", "application/json") req.Header.Set("Accept", "application/json")
resp, err := c.http.Do(req) resp, err := c.http.Do(req)
if err != nil { if err != nil {
return fmt.Errorf("aggregator GET %s: %w", path, err) return fmt.Errorf("GET %s: %w", path, err)
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
return fmt.Errorf("aggregator GET %s: status %d: %s", path, resp.StatusCode, string(body)) return fmt.Errorf("GET %s: %d: %s", path, resp.StatusCode, string(body))
} }
if out == nil { if out == nil {
return nil return nil
@ -97,16 +138,11 @@ func (c *Client) get(ctx context.Context, path string, out any) error {
return json.NewDecoder(resp.Body).Decode(out) return json.NewDecoder(resp.Body).Decode(out)
} }
// ListCardanoDBSnapshots returns the sorted-newest-first list of cardano-database snapshots.
func (c *Client) ListCardanoDBSnapshots(ctx context.Context) ([]CardanoDBSnapshot, error) { func (c *Client) ListCardanoDBSnapshots(ctx context.Context) ([]CardanoDBSnapshot, error) {
var out []CardanoDBSnapshot var out []CardanoDBSnapshot
if err := c.get(ctx, "/artifact/cardano-database", &out); err != nil { return out, c.get(ctx, "/artifact/cardano-database", &out)
return nil, err
}
return out, nil
} }
// GetCardanoDBSnapshot fetches details for a single snapshot by hash (or "latest").
func (c *Client) GetCardanoDBSnapshot(ctx context.Context, hash string) (*CardanoDBSnapshot, error) { func (c *Client) GetCardanoDBSnapshot(ctx context.Context, hash string) (*CardanoDBSnapshot, error) {
var out CardanoDBSnapshot var out CardanoDBSnapshot
if err := c.get(ctx, "/artifact/cardano-database/"+url.PathEscape(hash), &out); err != nil { if err := c.get(ctx, "/artifact/cardano-database/"+url.PathEscape(hash), &out); err != nil {
@ -115,7 +151,6 @@ func (c *Client) GetCardanoDBSnapshot(ctx context.Context, hash string) (*Cardan
return &out, nil return &out, nil
} }
// GetCertificate fetches a certificate by hash for signature verification.
func (c *Client) GetCertificate(ctx context.Context, hash string) (*Certificate, error) { func (c *Client) GetCertificate(ctx context.Context, hash string) (*Certificate, error) {
var out Certificate var out Certificate
if err := c.get(ctx, "/certificate/"+url.PathEscape(hash), &out); err != nil { if err := c.get(ctx, "/certificate/"+url.PathEscape(hash), &out); err != nil {

View file

@ -1,27 +1,136 @@
// Package artifact handles downloading and extracting Mithril snapshot artifacts. // Package artifact handles downloading and extracting Mithril snapshot artifacts.
// Currently stubs — HTTP range requests, resumable downloads, zstd+tar extraction
// will be implemented in the next pass.
package artifact package artifact
import ( import (
"context" "context"
"crypto/sha256"
"encoding/hex"
"errors" "errors"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"time"
) )
var ErrNotImplemented = errors.New("not yet implemented") // Download fetches a URL to destPath, resuming from a .part file if one
// exists. If expectedSHA256 is non-empty, the final file is integrity-checked.
// Download fetches an artifact from one of the supplied locations, choosing // Progress is reported via the supplied callback (called with current bytes).
// the first reachable one and storing it at destPath. //
// Implementation will do: // Design notes:
// - parallel range-chunks over HTTP // - No parallel chunks yet; a single streaming GET is fine for sub-GB
// - resume on partial .part file // artifacts and keeps the first working version simple. Range-chunk
// - SHA-256 verification against the snapshot manifest // parallelism will land in v2 once extraction is end-to-end tested.
func Download(ctx context.Context, locations []string, destPath string) error { // - Resume is implemented via the HTTP Range header against the existing
return ErrNotImplemented // .part file size; falls back to full download if the server refuses.
// - destPath is atomically replaced only after SHA validation passes.
func Download(ctx context.Context, uri, destPath, expectedSHA256 string, progress func(bytes int64)) error {
if err := os.MkdirAll(filepath.Dir(destPath), 0o755); err != nil {
return fmt.Errorf("mkdir: %w", err)
}
partPath := destPath + ".part"
var existing int64
if fi, err := os.Stat(partPath); err == nil {
existing = fi.Size()
} }
// Extract decompresses a zstd+tar archive into targetDir. req, err := http.NewRequestWithContext(ctx, http.MethodGet, uri, nil)
// Will stream through zstd -> tar reader without buffering the full archive. if err != nil {
func Extract(ctx context.Context, archivePath, targetDir string) error { return err
return ErrNotImplemented }
if existing > 0 {
req.Header.Set("Range", fmt.Sprintf("bytes=%d-", existing))
}
client := &http.Client{Timeout: 0} // artifacts can be GB-scale
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("GET %s: %w", uri, err)
}
defer resp.Body.Close()
var out *os.File
switch resp.StatusCode {
case http.StatusPartialContent:
out, err = os.OpenFile(partPath, os.O_APPEND|os.O_WRONLY, 0o644)
case http.StatusOK:
// Server ignored our range; start over.
existing = 0
out, err = os.Create(partPath)
default:
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048))
return fmt.Errorf("GET %s: %d: %s", uri, resp.StatusCode, string(body))
}
if err != nil {
return fmt.Errorf("open part: %w", err)
}
defer out.Close()
h := sha256.New()
// If we're resuming, we need to re-hash the existing bytes.
if existing > 0 {
prev, err := os.Open(partPath)
if err == nil {
io.Copy(h, prev)
prev.Close()
}
}
w := io.MultiWriter(out, h)
total := existing
buf := make([]byte, 256*1024)
lastProgress := time.Now()
for {
n, rerr := resp.Body.Read(buf)
if n > 0 {
if _, werr := w.Write(buf[:n]); werr != nil {
return fmt.Errorf("write: %w", werr)
}
total += int64(n)
if progress != nil && time.Since(lastProgress) > 250*time.Millisecond {
progress(total)
lastProgress = time.Now()
}
}
if rerr == io.EOF {
break
}
if rerr != nil {
return fmt.Errorf("read: %w", rerr)
}
}
if progress != nil {
progress(total)
}
if err := out.Close(); err != nil {
return err
}
if expectedSHA256 != "" {
got := hex.EncodeToString(h.Sum(nil))
if got != expectedSHA256 {
return fmt.Errorf("SHA256 mismatch: want %s, got %s", expectedSHA256, got)
}
}
return os.Rename(partPath, destPath)
}
var ErrNoLocations = errors.New("no download locations available")
// DownloadFirst tries each URI in order until one succeeds.
func DownloadFirst(ctx context.Context, uris []string, destPath, expectedSHA256 string, progress func(int64)) error {
if len(uris) == 0 {
return ErrNoLocations
}
var lastErr error
for _, uri := range uris {
if err := Download(ctx, uri, destPath, expectedSHA256, progress); err != nil {
lastErr = err
continue
}
return nil
}
return fmt.Errorf("all locations failed: last error: %w", lastErr)
} }

View file

@ -0,0 +1,91 @@
package artifact
import (
"archive/tar"
"context"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"github.com/klauspost/compress/zstd"
)
// ExtractZstdTar decompresses a .tar.zst archive into targetDir, streaming
// through the reader without buffering the full archive. Refuses entries
// with ".." in the path or absolute paths (tar-slip defense).
func ExtractZstdTar(ctx context.Context, archivePath, targetDir string) error {
f, err := os.Open(archivePath)
if err != nil {
return fmt.Errorf("open archive: %w", err)
}
defer f.Close()
zr, err := zstd.NewReader(f)
if err != nil {
return fmt.Errorf("zstd reader: %w", err)
}
defer zr.Close()
tr := tar.NewReader(zr)
if err := os.MkdirAll(targetDir, 0o755); err != nil {
return fmt.Errorf("mkdir target: %w", err)
}
cleanTarget, err := filepath.Abs(targetDir)
if err != nil {
return err
}
for {
if err := ctx.Err(); err != nil {
return err
}
hdr, err := tr.Next()
if err == io.EOF {
break
}
if err != nil {
return fmt.Errorf("tar next: %w", err)
}
// tar-slip defense
clean := filepath.Clean(hdr.Name)
if strings.HasPrefix(clean, "..") || filepath.IsAbs(clean) {
return fmt.Errorf("refusing suspicious archive path: %s", hdr.Name)
}
outPath := filepath.Join(cleanTarget, clean)
if !strings.HasPrefix(filepath.Clean(outPath)+string(os.PathSeparator), cleanTarget+string(os.PathSeparator)) &&
filepath.Clean(outPath) != cleanTarget {
return fmt.Errorf("refusing path outside target: %s", hdr.Name)
}
switch hdr.Typeflag {
case tar.TypeDir:
if err := os.MkdirAll(outPath, os.FileMode(hdr.Mode)); err != nil {
return err
}
case tar.TypeReg:
if err := os.MkdirAll(filepath.Dir(outPath), 0o755); err != nil {
return err
}
out, err := os.OpenFile(outPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(hdr.Mode))
if err != nil {
return fmt.Errorf("create %s: %w", outPath, err)
}
if _, err := io.Copy(out, tr); err != nil {
out.Close()
return fmt.Errorf("write %s: %w", outPath, err)
}
if err := out.Close(); err != nil {
return err
}
case tar.TypeSymlink, tar.TypeLink:
// Refuse links for safety — a Mithril archive has no legitimate reason to contain them.
return fmt.Errorf("refusing link entry: %s", hdr.Name)
default:
// Silently skip unknown types.
}
}
return nil
}