// Package artifact handles downloading and extracting Mithril snapshot artifacts. package artifact import ( "context" "crypto/sha256" "encoding/hex" "errors" "fmt" "io" "net/http" "os" "path/filepath" "time" ) // ProgressFn is called with (bytesRead, totalBytes) where totalBytes is // taken from the response Content-Length (0 if unknown — server didn't send). type ProgressFn func(read, total int64) // Download fetches a URL to destPath, resuming from a .part file if one // exists. If expectedSHA256 is non-empty, the final file is integrity-checked. // Progress is reported via the supplied callback (called with current bytes). // // Design notes: // - No parallel chunks yet; a single streaming GET is fine for sub-GB // artifacts and keeps the first working version simple. Range-chunk // parallelism will land in v2 once extraction is end-to-end tested. // - Resume is implemented via the HTTP Range header against the existing // .part file size; falls back to full download if the server refuses. // - destPath is atomically replaced only after SHA validation passes. func Download(ctx context.Context, uri, destPath, expectedSHA256 string, progress ProgressFn) error { if err := os.MkdirAll(filepath.Dir(destPath), 0o755); err != nil { return fmt.Errorf("mkdir: %w", err) } partPath := destPath + ".part" var existing int64 if fi, err := os.Stat(partPath); err == nil { existing = fi.Size() } req, err := http.NewRequestWithContext(ctx, http.MethodGet, uri, nil) if err != nil { return err } if existing > 0 { req.Header.Set("Range", fmt.Sprintf("bytes=%d-", existing)) } client := &http.Client{Timeout: 0} // artifacts can be GB-scale resp, err := client.Do(req) if err != nil { return fmt.Errorf("GET %s: %w", uri, err) } defer resp.Body.Close() var out *os.File switch resp.StatusCode { case http.StatusPartialContent: out, err = os.OpenFile(partPath, os.O_APPEND|os.O_WRONLY, 0o644) case http.StatusOK: // Server ignored our range; start over. existing = 0 out, err = os.Create(partPath) default: body, _ := io.ReadAll(io.LimitReader(resp.Body, 2048)) return fmt.Errorf("GET %s: %d: %s", uri, resp.StatusCode, string(body)) } if err != nil { return fmt.Errorf("open part: %w", err) } defer out.Close() h := sha256.New() // If we're resuming, we need to re-hash the existing bytes. if existing > 0 { prev, err := os.Open(partPath) if err == nil { io.Copy(h, prev) prev.Close() } } // totalBytes = existing + response.ContentLength if server sent one; // for resumed partial responses Content-Length is the remaining, not total. var totalSize int64 if resp.ContentLength > 0 { totalSize = existing + resp.ContentLength } w := io.MultiWriter(out, h) read := existing buf := make([]byte, 256*1024) lastProgress := time.Now() for { n, rerr := resp.Body.Read(buf) if n > 0 { if _, werr := w.Write(buf[:n]); werr != nil { return fmt.Errorf("write: %w", werr) } read += int64(n) if progress != nil && time.Since(lastProgress) > 250*time.Millisecond { progress(read, totalSize) lastProgress = time.Now() } } if rerr == io.EOF { break } if rerr != nil { return fmt.Errorf("read: %w", rerr) } } if progress != nil { progress(read, totalSize) } if err := out.Close(); err != nil { return err } if expectedSHA256 != "" { got := hex.EncodeToString(h.Sum(nil)) if got != expectedSHA256 { return fmt.Errorf("SHA256 mismatch: want %s, got %s", expectedSHA256, got) } } return os.Rename(partPath, destPath) } var ErrNoLocations = errors.New("no download locations available") // DownloadFirst tries each URI in order until one succeeds. func DownloadFirst(ctx context.Context, uris []string, destPath, expectedSHA256 string, progress ProgressFn) error { if len(uris) == 0 { return ErrNoLocations } var lastErr error for _, uri := range uris { if err := Download(ctx, uri, destPath, expectedSHA256, progress); err != nil { lastErr = err continue } return nil } return fmt.Errorf("all locations failed: last error: %w", lastErr) }