- L1: gofmt fix on models.go:81 - L2: rewrite misleading RunFailure doc comment (didn't actually embed APIError) - L3: tighten Client doc to warn against post-construction field mutation - L4: errors.New for non-formatting Errorf calls - L5: add TestUploadFile lifting coverage from 0% → 100% on UploadFile - L7: add context cancellation mid-multipart test Audit: memory/clawdforge-audits/go-3c62613.md
562 lines
16 KiB
Go
562 lines
16 KiB
Go
package clawdforge
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"io"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// helper: spin up an httptest.Server with the supplied handler and return
|
|
// (client, teardown).
|
|
func newTestClient(t *testing.T, h http.HandlerFunc) (*Client, func()) {
|
|
t.Helper()
|
|
srv := httptest.NewServer(h)
|
|
c := New(srv.URL, "cf_test_token")
|
|
return c, srv.Close
|
|
}
|
|
|
|
func TestHealthz(t *testing.T) {
|
|
c, done := newTestClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/healthz" {
|
|
t.Errorf("path = %q, want /healthz", r.URL.Path)
|
|
}
|
|
if r.Method != http.MethodGet {
|
|
t.Errorf("method = %q, want GET", r.Method)
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"ok":true,"claude_present":true,"claude_version":"1.2.3"}`))
|
|
})
|
|
defer done()
|
|
|
|
h, err := c.Healthz(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("Healthz: %v", err)
|
|
}
|
|
if !h.OK || !h.ClaudePresent || h.ClaudeVersion != "1.2.3" {
|
|
t.Errorf("got %+v", h)
|
|
}
|
|
}
|
|
|
|
func TestRunSuccessJSON(t *testing.T) {
|
|
c, done := newTestClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/run" {
|
|
t.Errorf("path = %q", r.URL.Path)
|
|
}
|
|
if got := r.Header.Get("Authorization"); got != "Bearer cf_test_token" {
|
|
t.Errorf("Authorization = %q", got)
|
|
}
|
|
var body RunRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
|
t.Fatalf("decode: %v", err)
|
|
}
|
|
if body.Prompt != "say hi" || body.Model != "sonnet" {
|
|
t.Errorf("body = %+v", body)
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"ok":true,"result":{"hello":"world"},"duration_ms":42,"stop_reason":"end_turn"}`))
|
|
})
|
|
defer done()
|
|
|
|
res, err := c.Run(context.Background(), RunRequest{Prompt: "say hi", Model: "sonnet"})
|
|
if err != nil {
|
|
t.Fatalf("Run: %v", err)
|
|
}
|
|
if !res.OK || res.DurationMS != 42 || res.StopReason != "end_turn" {
|
|
t.Errorf("got %+v", res)
|
|
}
|
|
var data map[string]string
|
|
if err := res.AsJSON(&data); err != nil {
|
|
t.Fatalf("AsJSON: %v", err)
|
|
}
|
|
if data["hello"] != "world" {
|
|
t.Errorf("AsJSON: got %v", data)
|
|
}
|
|
}
|
|
|
|
func TestRunSuccessText(t *testing.T) {
|
|
c, done := newTestClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"ok":true,"result":"plain text reply","duration_ms":10,"stop_reason":"end_turn"}`))
|
|
})
|
|
defer done()
|
|
|
|
res, err := c.Run(context.Background(), RunRequest{Prompt: "hi"})
|
|
if err != nil {
|
|
t.Fatalf("Run: %v", err)
|
|
}
|
|
s, err := res.AsText()
|
|
if err != nil {
|
|
t.Fatalf("AsText: %v", err)
|
|
}
|
|
if s != "plain text reply" {
|
|
t.Errorf("AsText = %q", s)
|
|
}
|
|
// AsJSON on a string result should still work — string is valid JSON
|
|
var got string
|
|
if err := res.AsJSON(&got); err != nil || got != "plain text reply" {
|
|
t.Errorf("AsJSON-on-string: got=%q err=%v", got, err)
|
|
}
|
|
}
|
|
|
|
func TestRunFailure502(t *testing.T) {
|
|
c, done := newTestClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusBadGateway)
|
|
_, _ = w.Write([]byte(`{"ok":false,"error":"timeout after 30s","stderr":"...","duration_ms":30000,"stop_reason":"timeout"}`))
|
|
})
|
|
defer done()
|
|
|
|
_, err := c.Run(context.Background(), RunRequest{Prompt: "loop forever"})
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
var rf *RunFailure
|
|
if !errors.As(err, &rf) {
|
|
t.Fatalf("err is not *RunFailure: %T %v", err, err)
|
|
}
|
|
if rf.StopReason != "timeout" || rf.DurationMS != 30000 {
|
|
t.Errorf("got %+v", rf)
|
|
}
|
|
}
|
|
|
|
func TestRunAuthFailure(t *testing.T) {
|
|
c, done := newTestClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
_, _ = w.Write([]byte(`{"detail":"missing bearer"}`))
|
|
})
|
|
defer done()
|
|
|
|
_, err := c.Run(context.Background(), RunRequest{Prompt: "x"})
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
if !errors.Is(err, ErrAuth) {
|
|
t.Errorf("err is not ErrAuth: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestRunGenericAPIError(t *testing.T) {
|
|
c, done := newTestClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusNotFound)
|
|
_, _ = w.Write([]byte(`{"error":"unknown file token: ff_xyz"}`))
|
|
})
|
|
defer done()
|
|
|
|
_, err := c.Run(context.Background(), RunRequest{Prompt: "x", Files: []string{"ff_xyz"}})
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
var apiErr *APIError
|
|
if !errors.As(err, &apiErr) {
|
|
t.Fatalf("err is not *APIError: %T %v", err, err)
|
|
}
|
|
if apiErr.StatusCode != 404 || !strings.Contains(apiErr.Message, "ff_xyz") {
|
|
t.Errorf("got %+v", apiErr)
|
|
}
|
|
}
|
|
|
|
func TestRunEmptyPromptRejected(t *testing.T) {
|
|
c := New("http://nowhere.invalid", "tok")
|
|
_, err := c.Run(context.Background(), RunRequest{})
|
|
if err == nil {
|
|
t.Fatal("expected error for empty prompt")
|
|
}
|
|
}
|
|
|
|
func TestUploadReader(t *testing.T) {
|
|
c, done := newTestClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/files" {
|
|
t.Errorf("path = %q", r.URL.Path)
|
|
}
|
|
ct := r.Header.Get("Content-Type")
|
|
if !strings.HasPrefix(ct, "multipart/form-data") {
|
|
t.Fatalf("Content-Type = %q", ct)
|
|
}
|
|
mr, err := r.MultipartReader()
|
|
if err != nil {
|
|
t.Fatalf("MultipartReader: %v", err)
|
|
}
|
|
var (
|
|
gotTTL string
|
|
gotFile string
|
|
gotData []byte
|
|
)
|
|
for {
|
|
part, err := mr.NextPart()
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
if err != nil {
|
|
t.Fatalf("NextPart: %v", err)
|
|
}
|
|
switch part.FormName() {
|
|
case "ttl_secs":
|
|
b, _ := io.ReadAll(part)
|
|
gotTTL = string(b)
|
|
case "file":
|
|
gotFile = part.FileName()
|
|
gotData, _ = io.ReadAll(part)
|
|
}
|
|
}
|
|
if gotTTL != "7200" {
|
|
t.Errorf("ttl_secs = %q", gotTTL)
|
|
}
|
|
if gotFile != "recipe.txt" {
|
|
t.Errorf("filename = %q", gotFile)
|
|
}
|
|
if string(gotData) != "hello world" {
|
|
t.Errorf("data = %q", string(gotData))
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"file_token":"ff_abc","ttl_secs":7200,"size":11}`))
|
|
})
|
|
defer done()
|
|
|
|
ft, err := c.UploadReader(context.Background(), "recipe.txt", strings.NewReader("hello world"), 7200)
|
|
if err != nil {
|
|
t.Fatalf("UploadReader: %v", err)
|
|
}
|
|
if ft.FileToken != "ff_abc" || ft.TTLSecs != 7200 || ft.Size != 11 {
|
|
t.Errorf("got %+v", ft)
|
|
}
|
|
}
|
|
|
|
// L5: exercise the file-path wrapper so coverage on UploadFile lifts off 0%.
|
|
// Writes a small tempfile, uploads it, and verifies the multipart envelope —
|
|
// notably that filepath.Base() is what hits the wire, not the full tmp path.
|
|
func TestUploadFile(t *testing.T) {
|
|
dir := t.TempDir()
|
|
fpath := filepath.Join(dir, "snippet.txt")
|
|
const payload = "uploaded-from-disk"
|
|
if err := os.WriteFile(fpath, []byte(payload), 0o600); err != nil {
|
|
t.Fatalf("WriteFile: %v", err)
|
|
}
|
|
|
|
c, done := newTestClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/files" {
|
|
t.Errorf("path = %q", r.URL.Path)
|
|
}
|
|
mr, err := r.MultipartReader()
|
|
if err != nil {
|
|
t.Fatalf("MultipartReader: %v", err)
|
|
}
|
|
var (
|
|
gotTTL string
|
|
gotFile string
|
|
gotData []byte
|
|
)
|
|
for {
|
|
part, err := mr.NextPart()
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
if err != nil {
|
|
t.Fatalf("NextPart: %v", err)
|
|
}
|
|
switch part.FormName() {
|
|
case "ttl_secs":
|
|
b, _ := io.ReadAll(part)
|
|
gotTTL = string(b)
|
|
case "file":
|
|
gotFile = part.FileName()
|
|
gotData, _ = io.ReadAll(part)
|
|
}
|
|
}
|
|
if gotTTL != "120" {
|
|
t.Errorf("ttl_secs = %q", gotTTL)
|
|
}
|
|
// Server should see basename only, not the full tmpdir path.
|
|
if gotFile != "snippet.txt" {
|
|
t.Errorf("filename = %q, want basename only", gotFile)
|
|
}
|
|
if string(gotData) != payload {
|
|
t.Errorf("data = %q", string(gotData))
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"file_token":"ff_disk","ttl_secs":120,"size":18}`))
|
|
})
|
|
defer done()
|
|
|
|
ft, err := c.UploadFile(context.Background(), fpath, 120)
|
|
if err != nil {
|
|
t.Fatalf("UploadFile: %v", err)
|
|
}
|
|
if ft.FileToken != "ff_disk" || ft.TTLSecs != 120 || ft.Size != 18 {
|
|
t.Errorf("got %+v", ft)
|
|
}
|
|
|
|
// And the missing-file path must return a wrapped open error.
|
|
if _, err := c.UploadFile(context.Background(), filepath.Join(dir, "does-not-exist"), 0); err == nil {
|
|
t.Error("expected error opening nonexistent file")
|
|
}
|
|
}
|
|
|
|
func TestCreateAndListAndRevokeToken(t *testing.T) {
|
|
created := false
|
|
listed := false
|
|
revoked := false
|
|
|
|
c, done := newTestClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
switch {
|
|
case r.Method == http.MethodPost && r.URL.Path == "/admin/tokens":
|
|
var body CreateTokenRequest
|
|
_ = json.NewDecoder(r.Body).Decode(&body)
|
|
if body.Name != "petalparse" {
|
|
t.Errorf("name = %q", body.Name)
|
|
}
|
|
created = true
|
|
_, _ = w.Write([]byte(`{"name":"petalparse","token":"cf_freshplaintext","ip_cidrs":["172.24.0.0/16"]}`))
|
|
case r.Method == http.MethodGet && r.URL.Path == "/admin/tokens":
|
|
listed = true
|
|
_, _ = w.Write([]byte(`{"tokens":[{"name":"petalparse","ip_cidrs":["172.24.0.0/16"],"created_at":1700000000}]}`))
|
|
case r.Method == http.MethodDelete && r.URL.Path == "/admin/tokens/petalparse":
|
|
revoked = true
|
|
_, _ = w.Write([]byte(`{"ok":true}`))
|
|
default:
|
|
t.Fatalf("unexpected %s %s", r.Method, r.URL.Path)
|
|
}
|
|
})
|
|
defer done()
|
|
|
|
tok, err := c.CreateToken(context.Background(), CreateTokenRequest{
|
|
Name: "petalparse",
|
|
IPCidrs: []string{"172.24.0.0/16"},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("CreateToken: %v", err)
|
|
}
|
|
if tok.Token != "cf_freshplaintext" {
|
|
t.Errorf("plaintext not returned: %+v", tok)
|
|
}
|
|
|
|
list, err := c.ListTokens(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("ListTokens: %v", err)
|
|
}
|
|
if len(list) != 1 || list[0].Name != "petalparse" {
|
|
t.Errorf("list = %+v", list)
|
|
}
|
|
|
|
if err := c.RevokeToken(context.Background(), "petalparse"); err != nil {
|
|
t.Fatalf("RevokeToken: %v", err)
|
|
}
|
|
|
|
if !created || !listed || !revoked {
|
|
t.Errorf("not all endpoints hit: c=%v l=%v r=%v", created, listed, revoked)
|
|
}
|
|
}
|
|
|
|
// L7: exercise the io.Pipe + goroutine + multipart writer interplay when
|
|
// the caller's context is cancelled mid-upload. Verifies that:
|
|
// 1. UploadReader returns a TransportError wrapping context.Canceled,
|
|
// 2. the producer goroutine and pipe close cleanly (no leak / no hang),
|
|
// 3. the source io.Reader stops being read once cancellation propagates.
|
|
func TestUploadReaderContextCancelMidStream(t *testing.T) {
|
|
// Server stalls so the upload is still in flight when we cancel.
|
|
handlerEntered := make(chan struct{})
|
|
handlerDone := make(chan struct{})
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
defer close(handlerDone)
|
|
close(handlerEntered)
|
|
// Read whatever the client sends until the request context dies or
|
|
// the client disconnects. Either way, just drain.
|
|
_, _ = io.Copy(io.Discard, r.Body)
|
|
}))
|
|
defer srv.Close()
|
|
c := New(srv.URL, "tok")
|
|
|
|
// A producer that emits chunks slowly so we're definitely mid-stream
|
|
// when cancel fires. It tracks calls so we can assert it stopped being
|
|
// read after cancellation.
|
|
slow := &slowReader{chunk: bytes.Repeat([]byte("x"), 256), delay: 5 * time.Millisecond, total: 1 << 20}
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
// Run UploadReader on a goroutine so we can cancel from the outside
|
|
// once we know the request has reached the server.
|
|
type result struct {
|
|
err error
|
|
}
|
|
resCh := make(chan result, 1)
|
|
go func() {
|
|
_, err := c.UploadReader(ctx, "big.bin", slow, 60)
|
|
resCh <- result{err: err}
|
|
}()
|
|
|
|
// Wait until the server is actively handling the request, then cancel.
|
|
select {
|
|
case <-handlerEntered:
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatal("server handler never entered — request did not reach server")
|
|
}
|
|
// Give the producer a moment to actually be writing into the pipe.
|
|
time.Sleep(50 * time.Millisecond)
|
|
cancel()
|
|
|
|
var got result
|
|
select {
|
|
case got = <-resCh:
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatal("UploadReader did not return after cancel — pipe/goroutine likely leaked")
|
|
}
|
|
|
|
if got.err == nil {
|
|
t.Fatal("expected error from cancelled upload")
|
|
}
|
|
var te *TransportError
|
|
if !errors.As(got.err, &te) {
|
|
t.Fatalf("err is not *TransportError: %T %v", got.err, got.err)
|
|
}
|
|
if !errors.Is(got.err, context.Canceled) {
|
|
t.Errorf("err does not wrap context.Canceled: %v", got.err)
|
|
}
|
|
|
|
// Snapshot reads-so-far, wait, then re-check: if cleanup worked the
|
|
// producer should have stopped being polled for more data.
|
|
before := slow.reads()
|
|
time.Sleep(150 * time.Millisecond)
|
|
after := slow.reads()
|
|
if after > before {
|
|
t.Errorf("producer kept being read after cancel: before=%d after=%d (pipe likely leaked)", before, after)
|
|
}
|
|
if before == 0 {
|
|
t.Error("producer was never read — test did not exercise multipart streaming")
|
|
}
|
|
|
|
// Force the stalled connection closed so the server handler unblocks.
|
|
srv.CloseClientConnections()
|
|
select {
|
|
case <-handlerDone:
|
|
case <-time.After(3 * time.Second):
|
|
t.Error("server handler did not exit")
|
|
}
|
|
}
|
|
|
|
// slowReader is an io.Reader for tests. Every Read first sleeps for delay and
// then hands out (a prefix of) chunk, until total bytes have been produced;
// after that it returns io.EOF. It counts Read calls so a test can assert
// that reading stopped.
type slowReader struct {
	chunk []byte        // bytes copied into the caller's buffer on each Read
	delay time.Duration // pause taken at the start of every Read
	total int           // number of bytes to produce before io.EOF

	mu       sync.Mutex // guards produced and calls
	produced int        // bytes handed out so far
	calls    int        // number of Read invocations, including the EOF ones
}

// Read sleeps for s.delay, then copies from s.chunk into p, never exceeding
// the bytes still owed towards s.total. It returns io.EOF once s.total bytes
// have been produced. Every call is counted, including those returning EOF.
func (s *slowReader) Read(p []byte) (int, error) {
	// Sleep before taking the lock so reads() is never blocked by the delay.
	time.Sleep(s.delay)

	s.mu.Lock()
	defer s.mu.Unlock()

	s.calls++
	remaining := s.total - s.produced
	if remaining <= 0 {
		return 0, io.EOF
	}
	n := copy(p, s.chunk)
	if n > remaining {
		// Report only the bytes still owed so total is never overshot.
		n = remaining
	}
	s.produced += n
	return n, nil
}

// reads reports how many times Read has been called so far.
func (s *slowReader) reads() int {
	s.mu.Lock()
	count := s.calls
	s.mu.Unlock()
	return count
}
|
|
|
|
func TestContextCancellation(t *testing.T) {
|
|
// Block server-side handler until the request context dies OR a hard
|
|
// safety timer fires (so a misbehaving client can't hang the test).
|
|
handlerDone := make(chan struct{})
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
select {
|
|
case <-r.Context().Done():
|
|
case <-time.After(2 * time.Second):
|
|
}
|
|
close(handlerDone)
|
|
}))
|
|
defer srv.Close()
|
|
c := New(srv.URL, "tok")
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
|
defer cancel()
|
|
|
|
_, err := c.Run(ctx, RunRequest{Prompt: "x"})
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
var te *TransportError
|
|
if !errors.As(err, &te) {
|
|
t.Fatalf("err is not *TransportError: %T %v", err, err)
|
|
}
|
|
if !errors.Is(err, context.DeadlineExceeded) {
|
|
t.Errorf("err does not wrap DeadlineExceeded: %v", err)
|
|
}
|
|
|
|
// Force the abandoned connection closed so the handler's r.Context()
|
|
// fires immediately rather than waiting for the safety timer.
|
|
srv.CloseClientConnections()
|
|
|
|
select {
|
|
case <-handlerDone:
|
|
case <-time.After(3 * time.Second):
|
|
t.Error("server handler did not exit")
|
|
}
|
|
}
|
|
|
|
func TestBaseURLTrailingSlashTrimmed(t *testing.T) {
|
|
c := New("http://example.com:8800/", "tok")
|
|
if c.BaseURL != "http://example.com:8800" {
|
|
t.Errorf("BaseURL = %q", c.BaseURL)
|
|
}
|
|
}
|
|
|
|
func TestNewWithClientNilFallback(t *testing.T) {
|
|
c := NewWithClient("http://x", "tok", nil)
|
|
if c.HTTPClient == nil {
|
|
t.Fatal("HTTPClient should be set")
|
|
}
|
|
}
|
|
|
|
// Sanity check: multipart payload is constructed correctly even when ttlSecs=0.
|
|
func TestUploadReaderDefaultTTL(t *testing.T) {
|
|
c, done := newTestClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
mr, _ := r.MultipartReader()
|
|
sawTTL := false
|
|
for {
|
|
part, err := mr.NextPart()
|
|
if err != nil {
|
|
break
|
|
}
|
|
if part.FormName() == "ttl_secs" {
|
|
sawTTL = true
|
|
}
|
|
}
|
|
if sawTTL {
|
|
t.Error("ttl_secs should be omitted when 0")
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"file_token":"ff_x","ttl_secs":3600,"size":3}`))
|
|
})
|
|
defer done()
|
|
|
|
if _, err := c.UploadReader(context.Background(), "a.bin", strings.NewReader("abc"), 0); err != nil {
|
|
t.Fatalf("UploadReader: %v", err)
|
|
}
|
|
}
|
|
|
|
// Blank reference that keeps the mime/multipart import in use, so the import
// is not reported as unused by the compiler.
|
|
var _ = multipart.NewWriter
|