- Session struct with idempotent Close(ctx) (atomic.Bool short-circuit)
- Client.NewSession(ctx, opts) / ListSessions(ctx) / GetSession(ctx, id)
- TurnResult.Text() helper concatenates text events
- Per-session sync.Mutex serializes concurrent Turn calls
- clawdforge_session_test.go: 9 tests
- README "Multi-turn / Sessions (v0.2)" section
v0.1 Run path unchanged.
Spec: memory/spec-clawdforge-v0.2.md
Server core: 940861f
424 lines
13 KiB
Go
424 lines
13 KiB
Go
package clawdforge
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// ---------- v0.2 Session tests ----------------------------------------------
|
|
|
|
// TestNewSessionAndClose exercises the create + close round-trip end to end:
|
|
// POST /sessions returns a handle, defer Close hits DELETE /sessions/{id},
|
|
// and the test asserts both endpoints actually got hit.
|
|
func TestNewSessionAndClose(t *testing.T) {
|
|
var (
|
|
gotCreate atomic.Bool
|
|
gotClose atomic.Bool
|
|
)
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
switch {
|
|
case r.Method == http.MethodPost && r.URL.Path == "/sessions":
|
|
gotCreate.Store(true)
|
|
if got := r.Header.Get("Authorization"); got != "Bearer cf_test_token" {
|
|
t.Errorf("Authorization = %q", got)
|
|
}
|
|
var body struct {
|
|
Agent string `json:"agent"`
|
|
Meta map[string]any `json:"meta"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
|
t.Fatalf("decode create body: %v", err)
|
|
}
|
|
if body.Agent != "claude" {
|
|
t.Errorf("agent = %q, want claude", body.Agent)
|
|
}
|
|
_, _ = w.Write([]byte(`{"ok":true,"session_id":"sess_abc","agent":"claude","created_at":1700000000}`))
|
|
case r.Method == http.MethodDelete && r.URL.Path == "/sessions/sess_abc":
|
|
gotClose.Store(true)
|
|
_, _ = w.Write([]byte(`{"ok":true}`))
|
|
default:
|
|
t.Fatalf("unexpected %s %s", r.Method, r.URL.Path)
|
|
}
|
|
}))
|
|
defer srv.Close()
|
|
|
|
c := New(srv.URL, "cf_test_token")
|
|
s, err := c.NewSession(context.Background(), &SessionOptions{Agent: "claude"})
|
|
if err != nil {
|
|
t.Fatalf("NewSession: %v", err)
|
|
}
|
|
if s.ID() != "sess_abc" || s.Agent() != "claude" || s.CreatedAt() != 1700000000 {
|
|
t.Errorf("session getters wrong: id=%q agent=%q createdAt=%d", s.ID(), s.Agent(), s.CreatedAt())
|
|
}
|
|
|
|
if err := s.Close(context.Background()); err != nil {
|
|
t.Fatalf("Close: %v", err)
|
|
}
|
|
if !gotCreate.Load() {
|
|
t.Error("POST /sessions never hit")
|
|
}
|
|
if !gotClose.Load() {
|
|
t.Error("DELETE /sessions/sess_abc never hit")
|
|
}
|
|
}
|
|
|
|
// TestSessionTurn round-trips a turn through a mocked /turn endpoint and
|
|
// verifies request body shape + response decoding (events, indices, timing).
|
|
func TestSessionTurn(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
switch {
|
|
case r.Method == http.MethodPost && r.URL.Path == "/sessions":
|
|
_, _ = w.Write([]byte(`{"ok":true,"session_id":"sess_t","agent":"claude","created_at":1}`))
|
|
case r.Method == http.MethodPost && r.URL.Path == "/sessions/sess_t/turn":
|
|
var body turnRequestBody
|
|
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
|
t.Fatalf("decode turn body: %v", err)
|
|
}
|
|
if body.Prompt != "summarize README" {
|
|
t.Errorf("prompt = %q", body.Prompt)
|
|
}
|
|
if len(body.Files) != 1 || body.Files[0] != "ff_xyz" {
|
|
t.Errorf("files = %v", body.Files)
|
|
}
|
|
// 1500ms → 2 secs (round up)
|
|
if body.TimeoutSecs != 2 {
|
|
t.Errorf("timeout_secs = %d, want 2 (1500ms rounds up)", body.TimeoutSecs)
|
|
}
|
|
_, _ = w.Write([]byte(`{
|
|
"ok": true,
|
|
"session_id": "sess_t",
|
|
"turn_index": 1,
|
|
"events": [
|
|
{"type":"thinking","content":"reading..."},
|
|
{"type":"tool_call","name":"Read","args":{"path":"README.md"},"result":{"len":42}},
|
|
{"type":"text","content":"Hello "},
|
|
{"type":"text","content":"world"}
|
|
],
|
|
"stop_reason": "end_turn",
|
|
"duration_ms": 1234
|
|
}`))
|
|
default:
|
|
t.Fatalf("unexpected %s %s", r.Method, r.URL.Path)
|
|
}
|
|
}))
|
|
defer srv.Close()
|
|
|
|
c := New(srv.URL, "tok")
|
|
s, err := c.NewSession(context.Background(), nil)
|
|
if err != nil {
|
|
t.Fatalf("NewSession: %v", err)
|
|
}
|
|
res, err := s.Turn(context.Background(), "summarize README", TurnOption{
|
|
Files: []string{"ff_xyz"},
|
|
TimeoutMs: 1500,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("Turn: %v", err)
|
|
}
|
|
if !res.Ok || res.TurnIndex != 1 || res.StopReason != "end_turn" || res.DurationMs != 1234 {
|
|
t.Errorf("got %+v", res)
|
|
}
|
|
if len(res.Events) != 4 {
|
|
t.Fatalf("events len = %d, want 4", len(res.Events))
|
|
}
|
|
if res.Events[1].Type != "tool_call" || res.Events[1].Name != "Read" {
|
|
t.Errorf("tool_call event = %+v", res.Events[1])
|
|
}
|
|
}
|
|
|
|
// TestSessionCloseIdempotent verifies the second Close short-circuits via
|
|
// the atomic flag and never hits the network.
|
|
func TestSessionCloseIdempotent(t *testing.T) {
|
|
var closeHits atomic.Int32
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
switch {
|
|
case r.Method == http.MethodPost && r.URL.Path == "/sessions":
|
|
_, _ = w.Write([]byte(`{"ok":true,"session_id":"sess_idem","agent":"claude","created_at":1}`))
|
|
case r.Method == http.MethodDelete && r.URL.Path == "/sessions/sess_idem":
|
|
closeHits.Add(1)
|
|
_, _ = w.Write([]byte(`{"ok":true}`))
|
|
default:
|
|
t.Fatalf("unexpected %s %s", r.Method, r.URL.Path)
|
|
}
|
|
}))
|
|
defer srv.Close()
|
|
|
|
c := New(srv.URL, "tok")
|
|
s, err := c.NewSession(context.Background(), nil)
|
|
if err != nil {
|
|
t.Fatalf("NewSession: %v", err)
|
|
}
|
|
if err := s.Close(context.Background()); err != nil {
|
|
t.Fatalf("Close 1: %v", err)
|
|
}
|
|
if err := s.Close(context.Background()); err != nil {
|
|
t.Fatalf("Close 2: %v", err)
|
|
}
|
|
if err := s.Close(context.Background()); err != nil {
|
|
t.Fatalf("Close 3: %v", err)
|
|
}
|
|
if got := closeHits.Load(); got != 1 {
|
|
t.Errorf("DELETE hit %d times, want exactly 1 (subsequent calls must short-circuit)", got)
|
|
}
|
|
}
|
|
|
|
// TestSessionConcurrentTurns dispatches two Turn calls on the same Session
|
|
// from two goroutines and verifies the server never sees them overlap —
|
|
// the per-session mutex must serialize them. The mock asserts no overlap by
|
|
// counting in-flight turns; if the lock leaked, count would exceed 1.
|
|
func TestSessionConcurrentTurns(t *testing.T) {
|
|
var (
|
|
inflight atomic.Int32
|
|
maxSeen atomic.Int32
|
|
hitCount atomic.Int32
|
|
)
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
switch {
|
|
case r.Method == http.MethodPost && r.URL.Path == "/sessions":
|
|
_, _ = w.Write([]byte(`{"ok":true,"session_id":"sess_conc","agent":"claude","created_at":1}`))
|
|
case r.Method == http.MethodPost && r.URL.Path == "/sessions/sess_conc/turn":
|
|
cur := inflight.Add(1)
|
|
defer inflight.Add(-1)
|
|
// Track high-water mark of concurrent turns on this session.
|
|
for {
|
|
m := maxSeen.Load()
|
|
if cur <= m || maxSeen.CompareAndSwap(m, cur) {
|
|
break
|
|
}
|
|
}
|
|
// Sleep so a leaky lock would let goroutine 2 enter while
|
|
// goroutine 1 is still in-flight.
|
|
time.Sleep(80 * time.Millisecond)
|
|
hitCount.Add(1)
|
|
_, _ = w.Write([]byte(`{"ok":true,"session_id":"sess_conc","turn_index":` +
|
|
strconvItoa(int(hitCount.Load())) +
|
|
`,"events":[],"stop_reason":"end_turn","duration_ms":1}`))
|
|
default:
|
|
t.Fatalf("unexpected %s %s", r.Method, r.URL.Path)
|
|
}
|
|
}))
|
|
defer srv.Close()
|
|
|
|
c := New(srv.URL, "tok")
|
|
s, err := c.NewSession(context.Background(), nil)
|
|
if err != nil {
|
|
t.Fatalf("NewSession: %v", err)
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
wg.Add(2)
|
|
errs := make([]error, 2)
|
|
go func() {
|
|
defer wg.Done()
|
|
_, errs[0] = s.Turn(context.Background(), "first")
|
|
}()
|
|
go func() {
|
|
defer wg.Done()
|
|
_, errs[1] = s.Turn(context.Background(), "second")
|
|
}()
|
|
wg.Wait()
|
|
|
|
for i, e := range errs {
|
|
if e != nil {
|
|
t.Errorf("turn %d: %v", i, e)
|
|
}
|
|
}
|
|
if hitCount.Load() != 2 {
|
|
t.Errorf("hitCount = %d, want 2", hitCount.Load())
|
|
}
|
|
if maxSeen.Load() > 1 {
|
|
t.Errorf("max concurrent in-flight turns = %d, want 1 (per-session mutex must serialize)", maxSeen.Load())
|
|
}
|
|
}
|
|
|
|
// TestListSessions verifies GET /sessions decodes into []SessionState.
|
|
func TestListSessions(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodGet || r.URL.Path != "/sessions" {
|
|
t.Fatalf("unexpected %s %s", r.Method, r.URL.Path)
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{
|
|
"ok": true,
|
|
"sessions": [
|
|
{"session_id":"sess_a","agent":"claude","app_name":"app1","created_at":100,"last_turn_at":150,"turn_count":3,"closed_at":null},
|
|
{"session_id":"sess_b","agent":"claude","app_name":"app1","created_at":200,"last_turn_at":null,"turn_count":0,"closed_at":250}
|
|
],
|
|
"count": 2
|
|
}`))
|
|
}))
|
|
defer srv.Close()
|
|
|
|
c := New(srv.URL, "tok")
|
|
list, err := c.ListSessions(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("ListSessions: %v", err)
|
|
}
|
|
if len(list) != 2 {
|
|
t.Fatalf("len = %d, want 2", len(list))
|
|
}
|
|
if list[0].SessionID != "sess_a" || list[0].TurnCount != 3 || list[0].ClosedAt != nil {
|
|
t.Errorf("sess_a wrong: %+v", list[0])
|
|
}
|
|
if list[1].LastTurnAt != nil {
|
|
t.Errorf("sess_b LastTurnAt should be nil, got %v", *list[1].LastTurnAt)
|
|
}
|
|
if list[1].ClosedAt == nil || *list[1].ClosedAt != 250 {
|
|
t.Errorf("sess_b ClosedAt wrong: %+v", list[1].ClosedAt)
|
|
}
|
|
}
|
|
|
|
// TestGetSession verifies GET /sessions/{id} decodes into a *SessionState.
|
|
func TestGetSession(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodGet || r.URL.Path != "/sessions/sess_get" {
|
|
t.Fatalf("unexpected %s %s", r.Method, r.URL.Path)
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{
|
|
"ok": true,
|
|
"session_id": "sess_get",
|
|
"agent": "claude",
|
|
"created_at": 555,
|
|
"last_turn_at": 600,
|
|
"turn_count": 2,
|
|
"closed_at": null
|
|
}`))
|
|
}))
|
|
defer srv.Close()
|
|
|
|
c := New(srv.URL, "tok")
|
|
st, err := c.GetSession(context.Background(), "sess_get")
|
|
if err != nil {
|
|
t.Fatalf("GetSession: %v", err)
|
|
}
|
|
if st.SessionID != "sess_get" || st.TurnCount != 2 {
|
|
t.Errorf("got %+v", st)
|
|
}
|
|
if st.LastTurnAt == nil || *st.LastTurnAt != 600 {
|
|
t.Errorf("LastTurnAt wrong: %+v", st.LastTurnAt)
|
|
}
|
|
}
|
|
|
|
// TestSessionCrossTokenIs404 verifies cross-token (or unknown-id) access
|
|
// surfaces as *APIError with StatusCode==404 — same shape as the v0.1 generic
|
|
// 404 path. No new error types.
|
|
func TestSessionCrossTokenIs404(t *testing.T) {
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusNotFound)
|
|
_, _ = w.Write([]byte(`{"detail":"session not found"}`))
|
|
}))
|
|
defer srv.Close()
|
|
|
|
c := New(srv.URL, "tok_b")
|
|
_, err := c.GetSession(context.Background(), "sess_belongs_to_a")
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
var apiErr *APIError
|
|
if !errors.As(err, &apiErr) {
|
|
t.Fatalf("err is not *APIError: %T %v", err, err)
|
|
}
|
|
if apiErr.StatusCode != 404 {
|
|
t.Errorf("StatusCode = %d, want 404", apiErr.StatusCode)
|
|
}
|
|
if !strings.Contains(apiErr.Message, "not found") {
|
|
t.Errorf("Message = %q, want it to mention 'not found'", apiErr.Message)
|
|
}
|
|
}
|
|
|
|
// TestTurnResultText verifies the Text() helper concatenates only "text"
|
|
// events and skips thinking / tool_call frames.
|
|
func TestTurnResultText(t *testing.T) {
|
|
r := &TurnResult{
|
|
Events: []TurnEvent{
|
|
{Type: "thinking", Content: "hmm "},
|
|
{Type: "text", Content: "hello "},
|
|
{Type: "tool_call", Name: "Read"},
|
|
{Type: "text", Content: "world"},
|
|
{Type: "text", Content: "!"},
|
|
},
|
|
}
|
|
if got := r.Text(); got != "hello world!" {
|
|
t.Errorf("Text() = %q, want %q", got, "hello world!")
|
|
}
|
|
// Empty / nil safety.
|
|
if (&TurnResult{}).Text() != "" {
|
|
t.Error("empty TurnResult.Text() should be empty string")
|
|
}
|
|
var nilR *TurnResult
|
|
if nilR.Text() != "" {
|
|
t.Error("nil TurnResult.Text() should be empty string")
|
|
}
|
|
}
|
|
|
|
// TestRunUnchanged is the v0.1 regression — adding the Session surface must
|
|
// not perturb the byte-on-the-wire shape of POST /run or its result decoding.
|
|
func TestRunUnchanged(t *testing.T) {
|
|
var sawBody RunRequest
|
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/run" || r.Method != http.MethodPost {
|
|
t.Fatalf("unexpected %s %s", r.Method, r.URL.Path)
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&sawBody); err != nil {
|
|
t.Fatalf("decode: %v", err)
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
_, _ = w.Write([]byte(`{"ok":true,"result":"plain","duration_ms":7,"stop_reason":"end_turn"}`))
|
|
}))
|
|
defer srv.Close()
|
|
|
|
c := New(srv.URL, "tok")
|
|
res, err := c.Run(context.Background(), RunRequest{Prompt: "hi", Model: "sonnet"})
|
|
if err != nil {
|
|
t.Fatalf("Run: %v", err)
|
|
}
|
|
if !res.OK || res.DurationMS != 7 {
|
|
t.Errorf("got %+v", res)
|
|
}
|
|
if sawBody.Prompt != "hi" || sawBody.Model != "sonnet" {
|
|
t.Errorf("server saw body %+v — v0.1 wire shape changed!", sawBody)
|
|
}
|
|
s, err := res.AsText()
|
|
if err != nil || s != "plain" {
|
|
t.Errorf("AsText = %q err=%v", s, err)
|
|
}
|
|
}
|
|
|
|
// strconvItoa formats n in base 10 and is meant to produce the same output
// as strconv.Itoa. It is a small hand-rolled shim so the concurrent-turn test
// can build a fresh JSON body per call without adding strconv to this file's
// imports.
//
// The magnitude is computed in uint64 rather than by negating n as an int:
// negating the most negative int overflows back to itself, which previously
// made this function return just "-".
func strconvItoa(n int) string {
	if n == 0 {
		return "0"
	}
	neg := n < 0
	u := uint64(n)
	if neg {
		// Unsigned negation wraps modulo 2^64, yielding |n| for every negative
		// int, including the minimum value.
		u = -u
	}
	var buf [20]byte // enough for 19 digits of an int64 plus a sign
	i := len(buf)
	for u > 0 {
		i--
		buf[i] = byte('0' + u%10)
		u /= 10
	}
	if neg {
		i--
		buf[i] = '-'
	}
	return string(buf[i:])
}
|