From ba189340e54567096f8efb255024b3612995d194 Mon Sep 17 00:00:00 2001 From: Kayos Date: Thu, 14 May 2026 08:03:38 -0700 Subject: [PATCH] port the subprocess transport MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Ports _internal/transport/subprocess_cli.py. Spawns the claude CLI with --output-format stream-json --input-format stream-json --verbose, then exposes a split (TransportReader, TransportWriter, TransportHandle) trio. The split is the key difference from the Python single-class transport: the reader half owns stdout exclusively and is moved into a background task; the writer half is Arc/Mutex over stdin and clones freely. A single mutex over the whole transport would deadlock the moment the reader blocked on stdout while a sender waited for the same lock to write to stdin — and the reader blocks that way after each turn. Other notes: - find_cli() mirrors the Python search path (PATH, then ~/.npm-global/bin, /usr/local/bin, ~/.local/bin, ~/node_modules/.bin, ~/.yarn/bin, ~/.claude/local/claude). - build_command() faithfully ports _build_command() with the v0.1 option subset. - Env handling matches Python: filter CLAUDECODE on inherit, set CLAUDE_CODE_ENTRYPOINT=sdk-rust, layer user env, stamp CLAUDE_AGENT_SDK_VERSION last. - Stdout JSON parsing speculatively accumulates until serde_json succeeds or max_buffer_size (1 MiB default) overflows — same buffer-and-retry loop as the Python TextReceiveStream path. Non-JSON chatter ([SandboxDebug] etc.) is skipped between frames. - TransportHandle::close() gives the subprocess a 5s graceful shutdown window after stdin EOF before SIGKILL, mirroring the #625 fix in the Python SDK. - Drop on TransportHandle starts a best-effort kill so abandoned clients do not leak claude processes. Unit tests cover the JSON accumulator (full + partial + complete, non-JSON skip, overflow, multiline split) and the version parser. 
--- src/transport.rs | 703 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 703 insertions(+) create mode 100644 src/transport.rs diff --git a/src/transport.rs b/src/transport.rs new file mode 100644 index 0000000..1afa235 --- /dev/null +++ b/src/transport.rs @@ -0,0 +1,703 @@ +//! Subprocess transport over the Claude CLI's `--output-format stream-json` +//! interface. +//! +//! The transport spawns `claude` with `--input-format stream-json +//! --output-format stream-json --verbose`, writes newline-delimited JSON +//! frames on stdin, and reads newline-delimited JSON frames from stdout. Each +//! stdout frame is buffered and speculatively parsed — `TextReceiveStream` in +//! Python and `BufReader::lines` in Rust can both split a single JSON object +//! across multiple `lines()` ticks under load — so we keep accumulating until +//! `serde_json::from_str` succeeds or the buffer overflows the configured +//! cap. +//! +//! Internally the transport splits cleanly into two halves once connected: +//! [`TransportReader`] owns stdout and is moved into a background task; +//! [`TransportWriter`] owns stdin (behind a `Mutex`) and is shared across +//! callers via `Arc`. This split is what lets [`crate::Client::send`] and the +//! reader task make progress concurrently — a single `Mutex` over the whole +//! transport would deadlock as soon as the reader blocked on stdout while a +//! sender waited for the same lock to write to stdin. +//! +//! This module is a port of `_internal/transport/subprocess_cli.py` from the +//! Python SDK, simplified for the v0.1 surface (no control protocol yet — +//! initialize/interrupt/can_use_tool are deferred to v0.2). 
+ +use std::collections::HashMap; +use std::env; +use std::path::{Path, PathBuf}; +use std::process::Stdio; +use std::sync::Arc; +use std::time::Duration; + +use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; +use tokio::process::{Child, ChildStdin, ChildStdout, Command}; +use tokio::sync::Mutex; + +use crate::errors::{Error, Result}; +use crate::options::{ClaudeAgentOptions, McpServersConfig, SystemPrompt}; + +/// Default 1 MiB cap on the stdout JSON-accumulation buffer. +const DEFAULT_MAX_BUFFER_SIZE: usize = 1024 * 1024; + +/// Minimum CLI version this SDK targets. +const MINIMUM_CLAUDE_CODE_VERSION: (u32, u32, u32) = (2, 0, 0); + +/// Subprocess transport for the `claude` CLI. +/// +/// Use [`SubprocessTransport::new`] to construct, then [`connect`] to spawn +/// the subprocess and obtain a `(reader, writer, handle)` split: +/// +/// ```ignore +/// let mut t = SubprocessTransport::new(opts)?; +/// let (reader, writer, handle) = t.connect().await?; +/// ``` +/// +/// The reader half is consumed by the message-pump task; the writer half is +/// cloned freely (`Arc>`-backed) and used to push user-message +/// frames. The handle is what you call [`close`] on to wait for the +/// subprocess to exit. +/// +/// Drop implements best-effort `start_kill()` on the child so an abandoned +/// transport doesn't leak a `claude` process. +/// +/// [`connect`]: SubprocessTransport::connect +/// [`close`]: TransportHandle::close +pub struct SubprocessTransport { + options: ClaudeAgentOptions, + cli_path: PathBuf, +} + +/// The stdout half of a connected transport. Moved into a dedicated reader +/// task by [`crate::Client`] / [`crate::query`]. +pub struct TransportReader { + stdout: BufReader, + buffer: String, + json_buffer: String, + max_buffer_size: usize, +} + +/// The stdin half of a connected transport. Cloneable handle that callers can +/// use to write newline-delimited JSON frames concurrently with the reader. 
+#[derive(Clone)] +pub struct TransportWriter { + inner: Arc, +} + +struct TransportWriterInner { + stdin: Mutex>, +} + +/// Owner of the child process. Calls [`TransportHandle::close`] to terminate +/// stdin, await the subprocess, and surface non-zero exit codes. +pub struct TransportHandle { + child: Option, + stderr_capture: Arc>, + writer: TransportWriter, +} + +impl SubprocessTransport { + /// Build a new transport from options. Resolves the CLI binary path eagerly + /// — pass an explicit [`ClaudeAgentOptions::cli_path`] to skip the search. + pub fn new(options: ClaudeAgentOptions) -> Result { + let cli_path = match options.cli_path.clone() { + Some(p) => p, + None => find_cli()?, + }; + Ok(Self { options, cli_path }) + } + + /// Resolved CLI binary path used to spawn the subprocess. + pub fn cli_path(&self) -> &Path { + &self.cli_path + } + + /// Spawn the subprocess and return reader / writer / handle. + pub async fn connect(self) -> Result<(TransportReader, TransportWriter, TransportHandle)> { + if !self.options.skip_version_check + && env::var("CLAUDE_AGENT_SDK_SKIP_VERSION_CHECK").is_err() + { + check_claude_version(&self.cli_path).await; + } + + let cmd_args = build_command(&self.options); + tracing::debug!(?cmd_args, "spawning claude CLI"); + + let mut command = Command::new(&self.cli_path); + command.args(&cmd_args); + command.stdin(Stdio::piped()); + command.stdout(Stdio::piped()); + if self.options.capture_stderr { + command.stderr(Stdio::piped()); + } else { + command.stderr(Stdio::inherit()); + } + + // Env handling matches the Python SDK: filter CLAUDECODE, set the + // entrypoint marker, then apply user-provided env (so user overrides + // win), then stamp our SDK version. 
+ let inherited: HashMap = env::vars() + .filter(|(k, _)| k != "CLAUDECODE") + .collect(); + command.env_clear(); + for (k, v) in inherited { + command.env(k, v); + } + command.env("CLAUDE_CODE_ENTRYPOINT", "sdk-rust"); + for (k, v) in &self.options.env { + command.env(k, v); + } + command.env("CLAUDE_AGENT_SDK_VERSION", env!("CARGO_PKG_VERSION")); + + if let Some(cwd) = &self.options.cwd { + command.current_dir(cwd); + command.env("PWD", cwd.to_string_lossy().to_string()); + } + + let mut child = command.spawn().map_err(|e| { + if let Some(cwd) = &self.options.cwd { + if !cwd.exists() { + return Error::conn(format!( + "Working directory does not exist: {}", + cwd.display() + )); + } + } + if e.kind() == std::io::ErrorKind::NotFound { + Error::CliNotFound(format!( + "Claude Code not found at: {}", + self.cli_path.display() + )) + } else { + Error::conn(format!("Failed to start Claude Code: {e}")) + } + })?; + + let stdin = child + .stdin + .take() + .ok_or_else(|| Error::conn("stdin not piped"))?; + let stdout = child + .stdout + .take() + .ok_or_else(|| Error::conn("stdout not piped"))?; + + let stderr_capture = Arc::new(Mutex::new(String::new())); + if self.options.capture_stderr { + let stderr = child + .stderr + .take() + .ok_or_else(|| Error::conn("stderr not piped"))?; + let sink = stderr_capture.clone(); + tokio::spawn(async move { + let mut reader = BufReader::new(stderr); + let mut line = String::new(); + loop { + line.clear(); + match reader.read_line(&mut line).await { + Ok(0) | Err(_) => break, + Ok(_) => { + let mut s = sink.lock().await; + s.push_str(&line); + // Cap captured buffer at 64 KiB to bound memory. 
+ if s.len() > 64 * 1024 { + let drop = s.len() - 64 * 1024; + s.drain(..drop); + } + } + } + } + }); + } + + let max_buffer_size = self.options.max_buffer_size.unwrap_or(DEFAULT_MAX_BUFFER_SIZE); + + let writer = TransportWriter { + inner: Arc::new(TransportWriterInner { + stdin: Mutex::new(Some(stdin)), + }), + }; + let reader = TransportReader { + stdout: BufReader::new(stdout), + buffer: String::new(), + json_buffer: String::new(), + max_buffer_size, + }; + let handle = TransportHandle { + child: Some(child), + stderr_capture, + writer: writer.clone(), + }; + + Ok((reader, writer, handle)) + } +} + +impl TransportReader { + /// Read the next JSON frame from stdout, accumulating partial lines into + /// a buffer until the buffer parses as JSON or overflows + /// `max_buffer_size`. + /// + /// Returns `Ok(None)` on clean EOF. + pub async fn read_frame(&mut self) -> Result> { + loop { + // Drain any complete JSON already buffered from a prior read. + if let Some(value) = + try_drain_buffer(&mut self.buffer, &mut self.json_buffer, self.max_buffer_size)? + { + return Ok(Some(value)); + } + + let mut line = String::new(); + let n = self + .stdout + .read_line(&mut line) + .await + .map_err(|e| Error::conn(format!("stdout read failed: {e}")))?; + if n == 0 { + // EOF — flush any trailing partial JSON as a last attempt. + if !self.json_buffer.trim().is_empty() { + let attempt = std::mem::take(&mut self.json_buffer); + return serde_json::from_str::(attempt.trim()) + .map(Some) + .map_err(|e| Error::JsonDecode { + line_preview: preview(&attempt), + source: e, + }); + } + return Ok(None); + } + self.buffer.push_str(&line); + } + } +} + +impl TransportWriter { + /// Write one newline-delimited JSON frame to stdin. Appends `'\n'` if the + /// caller didn't include one, and flushes. 
+ pub async fn write_frame(&self, data: &str) -> Result<()> { + let mut guard = self.inner.stdin.lock().await; + let stdin = guard + .as_mut() + .ok_or_else(|| Error::conn("transport stdin closed"))?; + stdin + .write_all(data.as_bytes()) + .await + .map_err(|e| Error::conn(format!("Failed to write stdin: {e}")))?; + if !data.ends_with('\n') { + stdin + .write_all(b"\n") + .await + .map_err(|e| Error::conn(format!("Failed to write stdin: {e}")))?; + } + stdin + .flush() + .await + .map_err(|e| Error::conn(format!("Failed to flush stdin: {e}")))?; + Ok(()) + } + + /// Close stdin so the subprocess can begin shutting down. Idempotent. + pub async fn end_input(&self) { + let mut guard = self.inner.stdin.lock().await; + guard.take(); // drop ChildStdin → close write end + } + + /// True if stdin has already been closed via [`end_input`] or the + /// underlying handle was dropped. + pub async fn is_closed(&self) -> bool { + self.inner.stdin.lock().await.is_none() + } +} + +impl TransportHandle { + /// Close stdin, wait for the subprocess to exit, and surface non-zero + /// exit codes as [`Error::Process`]. Idempotent. + pub async fn close(mut self) -> Result<()> { + // Drop stdin first (mirror Python SDK: graceful EOF). + self.writer.end_input().await; + + let Some(mut child) = self.child.take() else { + return Ok(()); + }; + + // 5s graceful shutdown window. + let status = match tokio::time::timeout(Duration::from_secs(5), child.wait()).await { + Ok(s) => s.map_err(Error::from)?, + Err(_) => { + let _ = child.start_kill(); + child + .wait() + .await + .map_err(|e| Error::conn(format!("subprocess wait failed: {e}")))? + } + }; + + if !status.success() { + let captured = self.stderr_capture.lock().await.clone(); + return Err(Error::Process { + message: "Command failed".into(), + exit_code: status.code(), + stderr: if captured.is_empty() { + None + } else { + Some(captured) + }, + }); + } + Ok(()) + } + + /// Best-effort kill, used by [`Drop`]. 
+ fn kill(&mut self) { + if let Some(child) = self.child.as_mut() { + let _ = child.start_kill(); + } + } +} + +impl Drop for TransportHandle { + fn drop(&mut self) { + self.kill(); + } +} + +async fn check_claude_version(cli_path: &Path) { + let res = tokio::time::timeout( + Duration::from_secs(2), + Command::new(cli_path).arg("-v").output(), + ) + .await; + let Ok(Ok(output)) = res else { return }; + let stdout = String::from_utf8_lossy(&output.stdout); + let re_match = stdout + .split_whitespace() + .find_map(parse_version); + if let Some((maj, min, patch)) = re_match { + if (maj, min, patch) < MINIMUM_CLAUDE_CODE_VERSION { + tracing::warn!( + "Claude Code version {maj}.{min}.{patch} at {} is below the minimum supported {}.{}.{}", + cli_path.display(), + MINIMUM_CLAUDE_CODE_VERSION.0, + MINIMUM_CLAUDE_CODE_VERSION.1, + MINIMUM_CLAUDE_CODE_VERSION.2, + ); + } + } +} + +/// Build the full argv for spawning `claude`. Mirrors `_build_command()` in +/// the Python SDK. +fn build_command(options: &ClaudeAgentOptions) -> Vec { + let mut cmd: Vec = vec![ + "--output-format".into(), + "stream-json".into(), + "--verbose".into(), + ]; + + match &options.system_prompt { + None => { + cmd.push("--system-prompt".into()); + cmd.push(String::new()); + } + Some(SystemPrompt::String(s)) => { + cmd.push("--system-prompt".into()); + cmd.push(s.clone()); + } + Some(SystemPrompt::PresetAppend(s)) => { + cmd.push("--append-system-prompt".into()); + cmd.push(s.clone()); + } + Some(SystemPrompt::File(p)) => { + cmd.push("--system-prompt-file".into()); + cmd.push(p.to_string_lossy().into_owned()); + } + } + + if let Some(tools) = &options.tools { + cmd.push("--tools".into()); + cmd.push(if tools.is_empty() { + String::new() + } else { + tools.join(",") + }); + } + + if !options.allowed_tools.is_empty() { + cmd.push("--allowedTools".into()); + cmd.push(options.allowed_tools.join(",")); + } + + if let Some(n) = options.max_turns { + cmd.push("--max-turns".into()); + 
cmd.push(n.to_string()); + } + + if let Some(usd) = options.max_budget_usd { + cmd.push("--max-budget-usd".into()); + cmd.push(usd.to_string()); + } + + if !options.disallowed_tools.is_empty() { + cmd.push("--disallowedTools".into()); + cmd.push(options.disallowed_tools.join(",")); + } + + if let Some(model) = &options.model { + cmd.push("--model".into()); + cmd.push(model.clone()); + } + + if let Some(model) = &options.fallback_model { + cmd.push("--fallback-model".into()); + cmd.push(model.clone()); + } + + if let Some(mode) = options.permission_mode { + cmd.push("--permission-mode".into()); + cmd.push(mode.as_cli_str().into()); + } + + if options.continue_conversation { + cmd.push("--continue".into()); + } + + if let Some(rid) = &options.resume { + cmd.push("--resume".into()); + cmd.push(rid.clone()); + } + + if let Some(sid) = &options.session_id { + cmd.push("--session-id".into()); + cmd.push(sid.clone()); + } + + if let Some(settings) = &options.settings { + cmd.push("--settings".into()); + cmd.push(settings.clone()); + } + + for dir in &options.add_dirs { + cmd.push("--add-dir".into()); + cmd.push(dir.to_string_lossy().into_owned()); + } + + match &options.mcp_servers { + Some(McpServersConfig::Inline(v)) => { + cmd.push("--mcp-config".into()); + cmd.push(v.to_string()); + } + Some(McpServersConfig::Path(p)) => { + cmd.push("--mcp-config".into()); + cmd.push(p.to_string_lossy().into_owned()); + } + None => {} + } + + if options.include_partial_messages { + cmd.push("--include-partial-messages".into()); + } + if options.include_hook_events { + cmd.push("--include-hook-events".into()); + } + if options.fork_session { + cmd.push("--fork-session".into()); + } + + for (flag, value) in &options.extra_args { + match value { + None => cmd.push(format!("--{flag}")), + Some(v) => { + cmd.push(format!("--{flag}")); + cmd.push(v.clone()); + } + } + } + + if let Some(effort) = options.effort { + cmd.push("--effort".into()); + cmd.push(effort.as_cli_str().into()); + } + + 
if let Some(of) = &options.output_format { + if of.get("type").and_then(|v| v.as_str()) == Some("json_schema") { + if let Some(schema) = of.get("schema") { + cmd.push("--json-schema".into()); + cmd.push(schema.to_string()); + } + } + } + + cmd.push("--input-format".into()); + cmd.push("stream-json".into()); + + cmd +} + +/// Search `PATH` and a small set of standard install locations for the +/// `claude` CLI binary. Mirrors `_find_cli()` in the Python SDK. +fn find_cli() -> Result { + if let Some(p) = which("claude") { + return Ok(p); + } + let home = env::var("HOME").ok().map(PathBuf::from); + let candidates: Vec = vec![ + home.as_ref().map(|h| h.join(".npm-global/bin/claude")), + Some(PathBuf::from("/usr/local/bin/claude")), + home.as_ref().map(|h| h.join(".local/bin/claude")), + home.as_ref().map(|h| h.join("node_modules/.bin/claude")), + home.as_ref().map(|h| h.join(".yarn/bin/claude")), + home.as_ref().map(|h| h.join(".claude/local/claude")), + ] + .into_iter() + .flatten() + .collect(); + for path in candidates { + if path.is_file() { + return Ok(path); + } + } + Err(Error::CliNotFound( + "Claude Code not found. Install with:\n npm install -g @anthropic-ai/claude-code\n\n\ + Or specify the path via ClaudeAgentOptions::with_cli_path()." + .into(), + )) +} + +fn which(binary: &str) -> Option { + let path = env::var_os("PATH")?; + for dir in env::split_paths(&path) { + let candidate = dir.join(binary); + if candidate.is_file() { + return Some(candidate); + } + if cfg!(windows) { + let with_ext = dir.join(format!("{binary}.exe")); + if with_ext.is_file() { + return Some(with_ext); + } + } + } + None +} + +/// Pull complete JSON objects out of `buffer` into `json_buffer`, parsing +/// speculatively. 
+fn try_drain_buffer( + buffer: &mut String, + json_buffer: &mut String, + max_buffer_size: usize, +) -> Result> { + loop { + let nl_pos = match buffer.find('\n') { + Some(p) => p, + None => return Ok(None), + }; + let line: String = buffer.drain(..=nl_pos).collect(); + let line_str = line.trim(); + if line_str.is_empty() { + continue; + } + if json_buffer.is_empty() && !line_str.starts_with('{') { + tracing::debug!("skipping non-JSON CLI stdout line: {}", preview(line_str)); + continue; + } + json_buffer.push_str(line_str); + if json_buffer.len() > max_buffer_size { + let len = json_buffer.len(); + let preview = preview(json_buffer); + json_buffer.clear(); + return Err(Error::JsonDecode { + line_preview: format!( + "JSON message exceeded maximum buffer size of {max_buffer_size} bytes \ + (buffer was {len} bytes; first bytes: {preview})" + ), + source: serde_json::from_str::("").unwrap_err(), + }); + } + match serde_json::from_str::(json_buffer) { + Ok(value) => { + json_buffer.clear(); + return Ok(Some(value)); + } + Err(e) => { + if e.is_eof() { + continue; + } + let preview = preview(json_buffer); + json_buffer.clear(); + return Err(Error::JsonDecode { + line_preview: preview, + source: e, + }); + } + } + } +} + +fn preview(s: &str) -> String { + s.chars().take(100).collect() +} + +fn parse_version(token: &str) -> Option<(u32, u32, u32)> { + let parts: Vec<&str> = token.split('.').collect(); + if parts.len() < 3 { + return None; + } + let major: u32 = parts[0].parse().ok()?; + let minor: u32 = parts[1].parse().ok()?; + let patch_str: String = parts[2].chars().take_while(|c| c.is_ascii_digit()).collect(); + let patch: u32 = patch_str.parse().ok()?; + Some((major, minor, patch)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn drain_full_then_partial_then_complete() { + let mut buf = String::from("{\"a\":1}\n{\"b\":"); + let mut jb = String::new(); + let v = try_drain_buffer(&mut buf, &mut jb, 1024).unwrap().unwrap(); + assert_eq!(v["a"], 1); + let 
next = try_drain_buffer(&mut buf, &mut jb, 1024).unwrap(); + assert!(next.is_none()); + buf.push_str("2}\n"); + let v = try_drain_buffer(&mut buf, &mut jb, 1024).unwrap().unwrap(); + assert_eq!(v["b"], 2); + } + + #[test] + fn drain_skips_non_json_chatter() { + let mut buf = String::from("[SandboxDebug] starting\n{\"ok\":true}\n"); + let mut jb = String::new(); + let v = try_drain_buffer(&mut buf, &mut jb, 1024).unwrap().unwrap(); + assert_eq!(v["ok"], true); + } + + #[test] + fn drain_overflow_errors() { + let mut buf = format!("{{\"x\":\"{}\"}}\n", "a".repeat(100)); + let mut jb = String::new(); + let res = try_drain_buffer(&mut buf, &mut jb, 10); + assert!(matches!(res, Err(Error::JsonDecode { .. }))); + } + + #[test] + fn drain_handles_multiline_split() { + let mut buf = String::from("{\"key\":\n"); + let mut jb = String::new(); + let v = try_drain_buffer(&mut buf, &mut jb, 1024).unwrap(); + assert!(v.is_none()); + buf.push_str("\"v\"}\n"); + let v = try_drain_buffer(&mut buf, &mut jb, 1024).unwrap().unwrap(); + assert_eq!(v["key"], "v"); + } + + #[test] + fn parse_version_basic() { + assert_eq!(parse_version("2.0.0"), Some((2, 0, 0))); + assert_eq!(parse_version("2.1.110"), Some((2, 1, 110))); + assert_eq!(parse_version("1.9.0-beta"), Some((1, 9, 0))); + assert_eq!(parse_version("abc"), None); + } +}